Spaces:
Runtime error
Runtime error
Delete video_to_video
Browse files- video_to_video/__init__.py +0 -0
- video_to_video/__pycache__/__init__.cpython-39.pyc +0 -0
- video_to_video/__pycache__/video_to_video_model.cpython-39.pyc +0 -0
- video_to_video/diffusion/__init__.py +0 -0
- video_to_video/diffusion/__pycache__/__init__.cpython-39.pyc +0 -0
- video_to_video/diffusion/__pycache__/diffusion_sdedit.cpython-39.pyc +0 -0
- video_to_video/diffusion/__pycache__/schedules_sdedit.cpython-39.pyc +0 -0
- video_to_video/diffusion/__pycache__/solvers_sdedit.cpython-39.pyc +0 -0
- video_to_video/diffusion/diffusion_sdedit.py +0 -443
- video_to_video/diffusion/schedules_sdedit.py +0 -85
- video_to_video/diffusion/solvers_sdedit.py +0 -204
- video_to_video/modules/__init__.py +0 -3
- video_to_video/modules/__pycache__/__init__.cpython-39.pyc +0 -0
- video_to_video/modules/__pycache__/embedder.cpython-39.pyc +0 -0
- video_to_video/modules/__pycache__/t5.cpython-39.pyc +0 -0
- video_to_video/modules/__pycache__/unet_v2v.cpython-39.pyc +0 -0
- video_to_video/modules/__pycache__/unet_v2v_LocalConv.cpython-39.pyc +0 -0
- video_to_video/modules/__pycache__/unet_v2v_deform.cpython-39.pyc +0 -0
- video_to_video/modules/embedder.py +0 -75
- video_to_video/modules/t5.py +0 -335
- video_to_video/modules/unet_v2v.py +0 -2332
- video_to_video/utils/__init__.py +0 -0
- video_to_video/utils/__pycache__/__init__.cpython-39.pyc +0 -0
- video_to_video/utils/__pycache__/config.cpython-39.pyc +0 -0
- video_to_video/utils/__pycache__/logger.cpython-39.pyc +0 -0
- video_to_video/utils/__pycache__/seed.cpython-39.pyc +0 -0
- video_to_video/utils/config.py +0 -169
- video_to_video/utils/logger.py +0 -94
- video_to_video/utils/seed.py +0 -14
- video_to_video/video_to_video_model.py +0 -237
video_to_video/__init__.py
DELETED
|
File without changes
|
video_to_video/__pycache__/__init__.cpython-39.pyc
DELETED
|
Binary file (153 Bytes)
|
|
|
video_to_video/__pycache__/video_to_video_model.cpython-39.pyc
DELETED
|
Binary file (6.97 kB)
|
|
|
video_to_video/diffusion/__init__.py
DELETED
|
File without changes
|
video_to_video/diffusion/__pycache__/__init__.cpython-39.pyc
DELETED
|
Binary file (163 Bytes)
|
|
|
video_to_video/diffusion/__pycache__/diffusion_sdedit.cpython-39.pyc
DELETED
|
Binary file (10.4 kB)
|
|
|
video_to_video/diffusion/__pycache__/schedules_sdedit.cpython-39.pyc
DELETED
|
Binary file (2.68 kB)
|
|
|
video_to_video/diffusion/__pycache__/solvers_sdedit.cpython-39.pyc
DELETED
|
Binary file (6.18 kB)
|
|
|
video_to_video/diffusion/diffusion_sdedit.py
DELETED
|
@@ -1,443 +0,0 @@
|
|
| 1 |
-
import random
|
| 2 |
-
|
| 3 |
-
import torch
|
| 4 |
-
|
| 5 |
-
from .schedules_sdedit import karras_schedule
|
| 6 |
-
from .solvers_sdedit import sample_dpmpp_2m_sde, sample_heun
|
| 7 |
-
|
| 8 |
-
from video_to_video.utils.logger import get_logger
|
| 9 |
-
|
| 10 |
-
logger = get_logger()
|
| 11 |
-
|
| 12 |
-
__all__ = ['GaussianDiffusion']
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
def _i(tensor, t, x):
|
| 16 |
-
shape = (x.size(0), ) + (1, ) * (x.ndim - 1)
|
| 17 |
-
return tensor[t.to(tensor.device)].view(shape).to(x.device)
|
| 18 |
-
|
| 19 |
-
class GaussianDiffusion(object):
|
| 20 |
-
|
| 21 |
-
def __init__(self, sigmas):
|
| 22 |
-
self.sigmas = sigmas
|
| 23 |
-
self.alphas = torch.sqrt(1 - sigmas**2)
|
| 24 |
-
self.num_timesteps = len(sigmas)
|
| 25 |
-
|
| 26 |
-
def diffuse(self, x0, t, noise=None):
|
| 27 |
-
noise = torch.randn_like(x0) if noise is None else noise
|
| 28 |
-
xt = _i(self.alphas, t, x0) * x0 + _i(self.sigmas, t, x0) * noise
|
| 29 |
-
|
| 30 |
-
return xt
|
| 31 |
-
|
| 32 |
-
def get_velocity(self, x0, xt, t):
|
| 33 |
-
sigmas = _i(self.sigmas, t, xt)
|
| 34 |
-
alphas = _i(self.alphas, t, xt)
|
| 35 |
-
velocity = (alphas * xt - x0) / sigmas
|
| 36 |
-
return velocity
|
| 37 |
-
|
| 38 |
-
def get_x0(self, v, xt, t):
|
| 39 |
-
sigmas = _i(self.sigmas, t, xt)
|
| 40 |
-
alphas = _i(self.alphas, t, xt)
|
| 41 |
-
x0 = alphas * xt - sigmas * v
|
| 42 |
-
return x0
|
| 43 |
-
|
| 44 |
-
def denoise(self,
|
| 45 |
-
xt,
|
| 46 |
-
t,
|
| 47 |
-
s,
|
| 48 |
-
model,
|
| 49 |
-
model_kwargs={},
|
| 50 |
-
guide_scale=None,
|
| 51 |
-
guide_rescale=None,
|
| 52 |
-
clamp=None,
|
| 53 |
-
percentile=None,
|
| 54 |
-
variant_info=None,):
|
| 55 |
-
s = t - 1 if s is None else s
|
| 56 |
-
|
| 57 |
-
# hyperparams
|
| 58 |
-
sigmas = _i(self.sigmas, t, xt)
|
| 59 |
-
alphas = _i(self.alphas, t, xt)
|
| 60 |
-
alphas_s = _i(self.alphas, s.clamp(0), xt)
|
| 61 |
-
alphas_s[s < 0] = 1.
|
| 62 |
-
sigmas_s = torch.sqrt(1 - alphas_s**2)
|
| 63 |
-
|
| 64 |
-
# precompute variables
|
| 65 |
-
betas = 1 - (alphas / alphas_s)**2
|
| 66 |
-
coef1 = betas * alphas_s / sigmas**2
|
| 67 |
-
coef2 = (alphas * sigmas_s**2) / (alphas_s * sigmas**2)
|
| 68 |
-
var = betas * (sigmas_s / sigmas)**2
|
| 69 |
-
log_var = torch.log(var).clamp_(-20, 20)
|
| 70 |
-
|
| 71 |
-
# prediction
|
| 72 |
-
if guide_scale is None:
|
| 73 |
-
assert isinstance(model_kwargs, dict)
|
| 74 |
-
out = model(xt, t=t, **model_kwargs)
|
| 75 |
-
else:
|
| 76 |
-
# classifier-free guidance
|
| 77 |
-
assert isinstance(model_kwargs, list)
|
| 78 |
-
if len(model_kwargs) > 3:
|
| 79 |
-
y_out = model(xt, t=t, **model_kwargs[0], **model_kwargs[2], **model_kwargs[3], **model_kwargs[4], **model_kwargs[5])
|
| 80 |
-
else:
|
| 81 |
-
y_out = model(xt, t=t, **model_kwargs[0], **model_kwargs[2], variant_info=variant_info)
|
| 82 |
-
if guide_scale == 1.:
|
| 83 |
-
out = y_out
|
| 84 |
-
else:
|
| 85 |
-
if len(model_kwargs) > 3:
|
| 86 |
-
u_out = model(xt, t=t, **model_kwargs[1], **model_kwargs[2], **model_kwargs[3], **model_kwargs[4], **model_kwargs[5])
|
| 87 |
-
else:
|
| 88 |
-
u_out = model(xt, t=t, **model_kwargs[1], **model_kwargs[2], variant_info=variant_info)
|
| 89 |
-
out = u_out + guide_scale * (y_out - u_out)
|
| 90 |
-
|
| 91 |
-
if guide_rescale is not None:
|
| 92 |
-
assert guide_rescale >= 0 and guide_rescale <= 1
|
| 93 |
-
ratio = (
|
| 94 |
-
y_out.flatten(1).std(dim=1) / # noqa
|
| 95 |
-
(out.flatten(1).std(dim=1) + 1e-12)
|
| 96 |
-
).view((-1, ) + (1, ) * (y_out.ndim - 1))
|
| 97 |
-
out *= guide_rescale * ratio + (1 - guide_rescale) * 1.0
|
| 98 |
-
|
| 99 |
-
x0 = alphas * xt - sigmas * out
|
| 100 |
-
|
| 101 |
-
# restrict the range of x0
|
| 102 |
-
if percentile is not None:
|
| 103 |
-
assert percentile > 0 and percentile <= 1
|
| 104 |
-
s = torch.quantile(x0.flatten(1).abs(), percentile, dim=1)
|
| 105 |
-
s = s.clamp_(1.0).view((-1, ) + (1, ) * (xt.ndim - 1))
|
| 106 |
-
x0 = torch.min(s, torch.max(-s, x0)) / s
|
| 107 |
-
elif clamp is not None:
|
| 108 |
-
x0 = x0.clamp(-clamp, clamp)
|
| 109 |
-
|
| 110 |
-
# recompute eps using the restricted x0
|
| 111 |
-
eps = (xt - alphas * x0) / sigmas
|
| 112 |
-
|
| 113 |
-
# compute mu (mean of posterior distribution) using the restricted x0
|
| 114 |
-
mu = coef1 * x0 + coef2 * xt
|
| 115 |
-
return mu, var, log_var, x0, eps
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
@torch.no_grad()
|
| 119 |
-
def sample(self,
|
| 120 |
-
noise,
|
| 121 |
-
model,
|
| 122 |
-
model_kwargs={},
|
| 123 |
-
condition_fn=None,
|
| 124 |
-
guide_scale=None,
|
| 125 |
-
guide_rescale=None,
|
| 126 |
-
clamp=None,
|
| 127 |
-
percentile=None,
|
| 128 |
-
solver='euler_a',
|
| 129 |
-
solver_mode='fast',
|
| 130 |
-
steps=20,
|
| 131 |
-
t_max=None,
|
| 132 |
-
t_min=None,
|
| 133 |
-
discretization=None,
|
| 134 |
-
discard_penultimate_step=None,
|
| 135 |
-
return_intermediate=None,
|
| 136 |
-
show_progress=False,
|
| 137 |
-
seed=-1,
|
| 138 |
-
chunk_inds=None,
|
| 139 |
-
**kwargs):
|
| 140 |
-
# sanity check
|
| 141 |
-
assert isinstance(steps, (int, torch.LongTensor))
|
| 142 |
-
assert t_max is None or (t_max > 0 and t_max <= self.num_timesteps - 1)
|
| 143 |
-
assert t_min is None or (t_min >= 0 and t_min < self.num_timesteps - 1)
|
| 144 |
-
assert discretization in (None, 'leading', 'linspace', 'trailing')
|
| 145 |
-
assert discard_penultimate_step in (None, True, False)
|
| 146 |
-
assert return_intermediate in (None, 'x0', 'xt')
|
| 147 |
-
|
| 148 |
-
# function of diffusion solver
|
| 149 |
-
solver_fn = {
|
| 150 |
-
'heun': sample_heun,
|
| 151 |
-
'dpmpp_2m_sde': sample_dpmpp_2m_sde
|
| 152 |
-
}[solver]
|
| 153 |
-
|
| 154 |
-
# options
|
| 155 |
-
schedule = 'karras' if 'karras' in solver else None
|
| 156 |
-
discretization = discretization or 'linspace'
|
| 157 |
-
seed = seed if seed >= 0 else random.randint(0, 2**31)
|
| 158 |
-
if isinstance(steps, torch.LongTensor):
|
| 159 |
-
discard_penultimate_step = False
|
| 160 |
-
if discard_penultimate_step is None:
|
| 161 |
-
discard_penultimate_step = True if solver in (
|
| 162 |
-
'dpm2', 'dpm2_ancestral', 'dpmpp_2m_sde', 'dpm2_karras',
|
| 163 |
-
'dpm2_ancestral_karras', 'dpmpp_2m_sde_karras') else False
|
| 164 |
-
|
| 165 |
-
# function for denoising xt to get x0
|
| 166 |
-
intermediates = []
|
| 167 |
-
|
| 168 |
-
def model_fn(xt, sigma):
|
| 169 |
-
# denoising
|
| 170 |
-
t = self._sigma_to_t(sigma).repeat(len(xt)).round().long()
|
| 171 |
-
x0 = self.denoise(xt, t, None, model, model_kwargs, guide_scale,
|
| 172 |
-
guide_rescale, clamp, percentile)[-2]
|
| 173 |
-
|
| 174 |
-
# collect intermediate outputs
|
| 175 |
-
if return_intermediate == 'xt':
|
| 176 |
-
intermediates.append(xt)
|
| 177 |
-
elif return_intermediate == 'x0':
|
| 178 |
-
intermediates.append(x0)
|
| 179 |
-
return x0
|
| 180 |
-
|
| 181 |
-
mask_cond = model_kwargs[3]['mask_cond']
|
| 182 |
-
def model_chunk_fn(xt, sigma):
|
| 183 |
-
# denoising
|
| 184 |
-
t = self._sigma_to_t(sigma).repeat(len(xt)).round().long()
|
| 185 |
-
O_LEN = chunk_inds[0][-1]-chunk_inds[1][0]
|
| 186 |
-
cut_f_ind = O_LEN//2
|
| 187 |
-
|
| 188 |
-
results_list = []
|
| 189 |
-
for i in range(len(chunk_inds)):
|
| 190 |
-
ind_start, ind_end = chunk_inds[i]
|
| 191 |
-
xt_chunk = xt[:,:,ind_start:ind_end].clone()
|
| 192 |
-
cur_f = xt_chunk.size(2)
|
| 193 |
-
model_kwargs[3]['mask_cond'] = mask_cond[:,ind_start:ind_end].clone()
|
| 194 |
-
x0_chunk = self.denoise(xt_chunk, t, None, model, model_kwargs, guide_scale,
|
| 195 |
-
guide_rescale, clamp, percentile)[-2]
|
| 196 |
-
if i == 0:
|
| 197 |
-
results_list.append(x0_chunk[:,:,:cur_f+cut_f_ind-O_LEN])
|
| 198 |
-
elif i == len(chunk_inds)-1:
|
| 199 |
-
results_list.append(x0_chunk[:,:,cut_f_ind:])
|
| 200 |
-
else:
|
| 201 |
-
results_list.append(x0_chunk[:,:,cut_f_ind:cur_f+cut_f_ind-O_LEN])
|
| 202 |
-
x0 = torch.concat(results_list, dim=2)
|
| 203 |
-
torch.cuda.empty_cache()
|
| 204 |
-
return x0
|
| 205 |
-
|
| 206 |
-
# get timesteps
|
| 207 |
-
if isinstance(steps, int):
|
| 208 |
-
steps += 1 if discard_penultimate_step else 0
|
| 209 |
-
t_max = self.num_timesteps - 1 if t_max is None else t_max
|
| 210 |
-
t_min = 0 if t_min is None else t_min
|
| 211 |
-
|
| 212 |
-
# discretize timesteps
|
| 213 |
-
if discretization == 'leading':
|
| 214 |
-
steps = torch.arange(t_min, t_max + 1,
|
| 215 |
-
(t_max - t_min + 1) / steps).flip(0)
|
| 216 |
-
elif discretization == 'linspace':
|
| 217 |
-
steps = torch.linspace(t_max, t_min, steps)
|
| 218 |
-
elif discretization == 'trailing':
|
| 219 |
-
steps = torch.arange(t_max, t_min - 1,
|
| 220 |
-
-((t_max - t_min + 1) / steps))
|
| 221 |
-
if solver_mode == 'fast':
|
| 222 |
-
t_mid = 500
|
| 223 |
-
steps1 = torch.arange(t_max, t_mid - 1,
|
| 224 |
-
-((t_max - t_mid + 1) / 4))
|
| 225 |
-
steps2 = torch.arange(t_mid, t_min - 1,
|
| 226 |
-
-((t_mid - t_min + 1) / 11))
|
| 227 |
-
steps = torch.concat([steps1, steps2])
|
| 228 |
-
else:
|
| 229 |
-
raise NotImplementedError(
|
| 230 |
-
f'{discretization} discretization not implemented')
|
| 231 |
-
steps = steps.clamp_(t_min, t_max)
|
| 232 |
-
steps = torch.as_tensor(
|
| 233 |
-
steps, dtype=torch.float32, device=noise.device)
|
| 234 |
-
|
| 235 |
-
# get sigmas
|
| 236 |
-
sigmas = self._t_to_sigma(steps)
|
| 237 |
-
sigmas = torch.cat([sigmas, sigmas.new_zeros([1])])
|
| 238 |
-
if schedule == 'karras':
|
| 239 |
-
if sigmas[0] == float('inf'):
|
| 240 |
-
sigmas = karras_schedule(
|
| 241 |
-
n=len(steps) - 1,
|
| 242 |
-
sigma_min=sigmas[sigmas > 0].min().item(),
|
| 243 |
-
sigma_max=sigmas[sigmas < float('inf')].max().item(),
|
| 244 |
-
rho=7.).to(sigmas)
|
| 245 |
-
sigmas = torch.cat([
|
| 246 |
-
sigmas.new_tensor([float('inf')]), sigmas,
|
| 247 |
-
sigmas.new_zeros([1])
|
| 248 |
-
])
|
| 249 |
-
else:
|
| 250 |
-
sigmas = karras_schedule(
|
| 251 |
-
n=len(steps),
|
| 252 |
-
sigma_min=sigmas[sigmas > 0].min().item(),
|
| 253 |
-
sigma_max=sigmas.max().item(),
|
| 254 |
-
rho=7.).to(sigmas)
|
| 255 |
-
sigmas = torch.cat([sigmas, sigmas.new_zeros([1])])
|
| 256 |
-
if discard_penultimate_step:
|
| 257 |
-
sigmas = torch.cat([sigmas[:-2], sigmas[-1:]])
|
| 258 |
-
|
| 259 |
-
fn = model_chunk_fn if chunk_inds is not None else model_fn
|
| 260 |
-
x0 = solver_fn(
|
| 261 |
-
noise, fn, sigmas, show_progress=show_progress, **kwargs)
|
| 262 |
-
return (x0, intermediates) if return_intermediate is not None else x0
|
| 263 |
-
|
| 264 |
-
@torch.no_grad()
|
| 265 |
-
def sample_sr(self,
|
| 266 |
-
noise,
|
| 267 |
-
model,
|
| 268 |
-
model_kwargs={},
|
| 269 |
-
condition_fn=None,
|
| 270 |
-
guide_scale=None,
|
| 271 |
-
guide_rescale=None,
|
| 272 |
-
clamp=None,
|
| 273 |
-
percentile=None,
|
| 274 |
-
solver='euler_a',
|
| 275 |
-
solver_mode='fast',
|
| 276 |
-
steps=20,
|
| 277 |
-
t_max=None,
|
| 278 |
-
t_min=None,
|
| 279 |
-
discretization=None,
|
| 280 |
-
discard_penultimate_step=None,
|
| 281 |
-
return_intermediate=None,
|
| 282 |
-
show_progress=False,
|
| 283 |
-
seed=-1,
|
| 284 |
-
chunk_inds=None,
|
| 285 |
-
variant_info=None,
|
| 286 |
-
**kwargs):
|
| 287 |
-
# sanity check
|
| 288 |
-
assert isinstance(steps, (int, torch.LongTensor))
|
| 289 |
-
assert t_max is None or (t_max > 0 and t_max <= self.num_timesteps - 1)
|
| 290 |
-
assert t_min is None or (t_min >= 0 and t_min < self.num_timesteps - 1)
|
| 291 |
-
assert discretization in (None, 'leading', 'linspace', 'trailing')
|
| 292 |
-
assert discard_penultimate_step in (None, True, False)
|
| 293 |
-
assert return_intermediate in (None, 'x0', 'xt')
|
| 294 |
-
|
| 295 |
-
# function of diffusion solver
|
| 296 |
-
solver_fn = {
|
| 297 |
-
'heun': sample_heun,
|
| 298 |
-
'dpmpp_2m_sde': sample_dpmpp_2m_sde
|
| 299 |
-
}[solver]
|
| 300 |
-
|
| 301 |
-
# options
|
| 302 |
-
schedule = 'karras' if 'karras' in solver else None
|
| 303 |
-
discretization = discretization or 'linspace'
|
| 304 |
-
seed = seed if seed >= 0 else random.randint(0, 2**31)
|
| 305 |
-
if isinstance(steps, torch.LongTensor):
|
| 306 |
-
discard_penultimate_step = False
|
| 307 |
-
if discard_penultimate_step is None:
|
| 308 |
-
discard_penultimate_step = True if solver in (
|
| 309 |
-
'dpm2', 'dpm2_ancestral', 'dpmpp_2m_sde', 'dpm2_karras',
|
| 310 |
-
'dpm2_ancestral_karras', 'dpmpp_2m_sde_karras') else False
|
| 311 |
-
|
| 312 |
-
# function for denoising xt to get x0
|
| 313 |
-
intermediates = []
|
| 314 |
-
|
| 315 |
-
def model_fn(xt, sigma, variant_info=None):
|
| 316 |
-
# denoising
|
| 317 |
-
t = self._sigma_to_t(sigma).repeat(len(xt)).round().long()
|
| 318 |
-
x0 = self.denoise(xt, t, None, model, model_kwargs, guide_scale,
|
| 319 |
-
guide_rescale, clamp, percentile, variant_info=variant_info)[-2]
|
| 320 |
-
|
| 321 |
-
# collect intermediate outputs
|
| 322 |
-
if return_intermediate == 'xt':
|
| 323 |
-
intermediates.append(xt)
|
| 324 |
-
elif return_intermediate == 'x0':
|
| 325 |
-
print('add intermediate outputs x0')
|
| 326 |
-
intermediates.append(x0)
|
| 327 |
-
return x0
|
| 328 |
-
|
| 329 |
-
# mask_cond = model_kwargs[3]['mask_cond']
|
| 330 |
-
def model_chunk_fn(xt, sigma, variant_info=None):
|
| 331 |
-
# denoising
|
| 332 |
-
t = self._sigma_to_t(sigma).repeat(len(xt)).round().long()
|
| 333 |
-
O_LEN = chunk_inds[0][-1]-chunk_inds[1][0]
|
| 334 |
-
cut_f_ind = O_LEN//2
|
| 335 |
-
|
| 336 |
-
results_list = []
|
| 337 |
-
for i in range(len(chunk_inds)):
|
| 338 |
-
ind_start, ind_end = chunk_inds[i]
|
| 339 |
-
xt_chunk = xt[:,:,ind_start:ind_end].clone()
|
| 340 |
-
model_kwargs[2]['hint_chunk'] = model_kwargs[2]['hint'][:,:,ind_start:ind_end].clone() # new added
|
| 341 |
-
cur_f = xt_chunk.size(2)
|
| 342 |
-
# model_kwargs[3]['mask_cond'] = mask_cond[:,ind_start:ind_end].clone()
|
| 343 |
-
x0_chunk = self.denoise(xt_chunk, t, None, model, model_kwargs, guide_scale,
|
| 344 |
-
guide_rescale, clamp, percentile, variant_info=variant_info)[-2]
|
| 345 |
-
if i == 0:
|
| 346 |
-
results_list.append(x0_chunk[:,:,:cur_f+cut_f_ind-O_LEN])
|
| 347 |
-
elif i == len(chunk_inds)-1:
|
| 348 |
-
results_list.append(x0_chunk[:,:,cut_f_ind:])
|
| 349 |
-
else:
|
| 350 |
-
results_list.append(x0_chunk[:,:,cut_f_ind:cur_f+cut_f_ind-O_LEN])
|
| 351 |
-
x0 = torch.concat(results_list, dim=2)
|
| 352 |
-
torch.cuda.empty_cache()
|
| 353 |
-
return x0
|
| 354 |
-
|
| 355 |
-
# get timesteps
|
| 356 |
-
if isinstance(steps, int):
|
| 357 |
-
steps += 1 if discard_penultimate_step else 0
|
| 358 |
-
t_max = self.num_timesteps - 1 if t_max is None else t_max
|
| 359 |
-
t_min = 0 if t_min is None else t_min
|
| 360 |
-
|
| 361 |
-
# discretize timesteps
|
| 362 |
-
if discretization == 'leading':
|
| 363 |
-
steps = torch.arange(t_min, t_max + 1,
|
| 364 |
-
(t_max - t_min + 1) / steps).flip(0)
|
| 365 |
-
elif discretization == 'linspace':
|
| 366 |
-
steps = torch.linspace(t_max, t_min, steps)
|
| 367 |
-
elif discretization == 'trailing':
|
| 368 |
-
steps = torch.arange(t_max, t_min - 1,
|
| 369 |
-
-((t_max - t_min + 1) / steps))
|
| 370 |
-
if solver_mode == 'fast':
|
| 371 |
-
t_mid = 500
|
| 372 |
-
steps1 = torch.arange(t_max, t_mid - 1,
|
| 373 |
-
-((t_max - t_mid + 1) / 4))
|
| 374 |
-
steps2 = torch.arange(t_mid, t_min - 1,
|
| 375 |
-
-((t_mid - t_min + 1) / 11))
|
| 376 |
-
steps = torch.concat([steps1, steps2])
|
| 377 |
-
else:
|
| 378 |
-
raise NotImplementedError(
|
| 379 |
-
f'{discretization} discretization not implemented')
|
| 380 |
-
steps = steps.clamp_(t_min, t_max)
|
| 381 |
-
steps = torch.as_tensor(
|
| 382 |
-
steps, dtype=torch.float32, device=noise.device)
|
| 383 |
-
|
| 384 |
-
# get sigmas
|
| 385 |
-
sigmas = self._t_to_sigma(steps)
|
| 386 |
-
sigmas = torch.cat([sigmas, sigmas.new_zeros([1])])
|
| 387 |
-
if schedule == 'karras':
|
| 388 |
-
if sigmas[0] == float('inf'):
|
| 389 |
-
sigmas = karras_schedule(
|
| 390 |
-
n=len(steps) - 1,
|
| 391 |
-
sigma_min=sigmas[sigmas > 0].min().item(),
|
| 392 |
-
sigma_max=sigmas[sigmas < float('inf')].max().item(),
|
| 393 |
-
rho=7.).to(sigmas)
|
| 394 |
-
sigmas = torch.cat([
|
| 395 |
-
sigmas.new_tensor([float('inf')]), sigmas,
|
| 396 |
-
sigmas.new_zeros([1])
|
| 397 |
-
])
|
| 398 |
-
else:
|
| 399 |
-
sigmas = karras_schedule(
|
| 400 |
-
n=len(steps),
|
| 401 |
-
sigma_min=sigmas[sigmas > 0].min().item(),
|
| 402 |
-
sigma_max=sigmas.max().item(),
|
| 403 |
-
rho=7.).to(sigmas)
|
| 404 |
-
sigmas = torch.cat([sigmas, sigmas.new_zeros([1])])
|
| 405 |
-
if discard_penultimate_step:
|
| 406 |
-
sigmas = torch.cat([sigmas[:-2], sigmas[-1:]])
|
| 407 |
-
|
| 408 |
-
|
| 409 |
-
fn = model_chunk_fn if chunk_inds is not None else model_fn
|
| 410 |
-
x0 = solver_fn(
|
| 411 |
-
noise, fn, sigmas, variant_info=variant_info, show_progress=show_progress, **kwargs)
|
| 412 |
-
return (x0, intermediates) if return_intermediate is not None else x0
|
| 413 |
-
|
| 414 |
-
|
| 415 |
-
def _sigma_to_t(self, sigma):
|
| 416 |
-
if sigma == float('inf'):
|
| 417 |
-
t = torch.full_like(sigma, len(self.sigmas) - 1)
|
| 418 |
-
else:
|
| 419 |
-
log_sigmas = torch.sqrt(self.sigmas**2 / # noqa
|
| 420 |
-
(1 - self.sigmas**2)).log().to(sigma)
|
| 421 |
-
log_sigma = sigma.log()
|
| 422 |
-
dists = log_sigma - log_sigmas[:, None]
|
| 423 |
-
low_idx = dists.ge(0).cumsum(dim=0).argmax(dim=0).clamp(
|
| 424 |
-
max=log_sigmas.shape[0] - 2)
|
| 425 |
-
high_idx = low_idx + 1
|
| 426 |
-
low, high = log_sigmas[low_idx], log_sigmas[high_idx]
|
| 427 |
-
w = (low - log_sigma) / (low - high)
|
| 428 |
-
w = w.clamp(0, 1)
|
| 429 |
-
t = (1 - w) * low_idx + w * high_idx
|
| 430 |
-
t = t.view(sigma.shape)
|
| 431 |
-
if t.ndim == 0:
|
| 432 |
-
t = t.unsqueeze(0)
|
| 433 |
-
return t
|
| 434 |
-
|
| 435 |
-
def _t_to_sigma(self, t):
|
| 436 |
-
t = t.float()
|
| 437 |
-
low_idx, high_idx, w = t.floor().long(), t.ceil().long(), t.frac()
|
| 438 |
-
log_sigmas = torch.sqrt(self.sigmas**2 / # noqa
|
| 439 |
-
(1 - self.sigmas**2)).log().to(t)
|
| 440 |
-
log_sigma = (1 - w) * log_sigmas[low_idx] + w * log_sigmas[high_idx]
|
| 441 |
-
log_sigma[torch.isnan(log_sigma)
|
| 442 |
-
| torch.isinf(log_sigma)] = float('inf')
|
| 443 |
-
return log_sigma.exp()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
video_to_video/diffusion/schedules_sdedit.py
DELETED
|
@@ -1,85 +0,0 @@
|
|
| 1 |
-
# Copyright (c) Alibaba, Inc. and its affiliates.
|
| 2 |
-
|
| 3 |
-
import math
|
| 4 |
-
|
| 5 |
-
import torch
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
def betas_to_sigmas(betas):
|
| 9 |
-
return torch.sqrt(1 - torch.cumprod(1 - betas, dim=0))
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
def sigmas_to_betas(sigmas):
|
| 13 |
-
square_alphas = 1 - sigmas**2
|
| 14 |
-
betas = 1 - torch.cat(
|
| 15 |
-
[square_alphas[:1], square_alphas[1:] / square_alphas[:-1]])
|
| 16 |
-
return betas
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
def logsnrs_to_sigmas(logsnrs):
|
| 20 |
-
return torch.sqrt(torch.sigmoid(-logsnrs))
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
def sigmas_to_logsnrs(sigmas):
|
| 24 |
-
square_sigmas = sigmas**2
|
| 25 |
-
return torch.log(square_sigmas / (1 - square_sigmas))
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
def _logsnr_cosine(n, logsnr_min=-15, logsnr_max=15):
|
| 29 |
-
t_min = math.atan(math.exp(-0.5 * logsnr_min))
|
| 30 |
-
t_max = math.atan(math.exp(-0.5 * logsnr_max))
|
| 31 |
-
t = torch.linspace(1, 0, n)
|
| 32 |
-
logsnrs = -2 * torch.log(torch.tan(t_min + t * (t_max - t_min)))
|
| 33 |
-
return logsnrs
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
def _logsnr_cosine_shifted(n, logsnr_min=-15, logsnr_max=15, scale=2):
|
| 37 |
-
logsnrs = _logsnr_cosine(n, logsnr_min, logsnr_max)
|
| 38 |
-
logsnrs += 2 * math.log(1 / scale)
|
| 39 |
-
return logsnrs
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
def _logsnr_cosine_interp(n,
|
| 43 |
-
logsnr_min=-15,
|
| 44 |
-
logsnr_max=15,
|
| 45 |
-
scale_min=2,
|
| 46 |
-
scale_max=4):
|
| 47 |
-
t = torch.linspace(1, 0, n)
|
| 48 |
-
logsnrs_min = _logsnr_cosine_shifted(n, logsnr_min, logsnr_max, scale_min)
|
| 49 |
-
logsnrs_max = _logsnr_cosine_shifted(n, logsnr_min, logsnr_max, scale_max)
|
| 50 |
-
logsnrs = t * logsnrs_min + (1 - t) * logsnrs_max
|
| 51 |
-
return logsnrs
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
def karras_schedule(n, sigma_min=0.002, sigma_max=80.0, rho=7.0):
|
| 55 |
-
ramp = torch.linspace(1, 0, n)
|
| 56 |
-
min_inv_rho = sigma_min**(1 / rho)
|
| 57 |
-
max_inv_rho = sigma_max**(1 / rho)
|
| 58 |
-
sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho))**rho
|
| 59 |
-
sigmas = torch.sqrt(sigmas**2 / (1 + sigmas**2))
|
| 60 |
-
return sigmas
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
def logsnr_cosine_interp_schedule(n,
|
| 64 |
-
logsnr_min=-15,
|
| 65 |
-
logsnr_max=15,
|
| 66 |
-
scale_min=2,
|
| 67 |
-
scale_max=4):
|
| 68 |
-
return logsnrs_to_sigmas(
|
| 69 |
-
_logsnr_cosine_interp(n, logsnr_min, logsnr_max, scale_min, scale_max))
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
def noise_schedule(schedule='logsnr_cosine_interp',
|
| 73 |
-
n=1000,
|
| 74 |
-
zero_terminal_snr=False,
|
| 75 |
-
**kwargs):
|
| 76 |
-
# compute sigmas
|
| 77 |
-
sigmas = {
|
| 78 |
-
'logsnr_cosine_interp': logsnr_cosine_interp_schedule
|
| 79 |
-
}[schedule](n, **kwargs)
|
| 80 |
-
|
| 81 |
-
# post-processing
|
| 82 |
-
if zero_terminal_snr and sigmas.max() != 1.0:
|
| 83 |
-
scale = (1.0 - sigmas.min()) / (sigmas.max() - sigmas.min())
|
| 84 |
-
sigmas = sigmas.min() + scale * (sigmas - sigmas.min())
|
| 85 |
-
return sigmas
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
video_to_video/diffusion/solvers_sdedit.py
DELETED
|
@@ -1,204 +0,0 @@
|
|
| 1 |
-
# Copyright (c) Alibaba, Inc. and its affiliates.
|
| 2 |
-
|
| 3 |
-
import torch
|
| 4 |
-
import torchsde
|
| 5 |
-
from tqdm.auto import trange
|
| 6 |
-
|
| 7 |
-
from video_to_video.utils.logger import get_logger
|
| 8 |
-
|
| 9 |
-
logger = get_logger()
|
| 10 |
-
|
| 11 |
-
def get_ancestral_step(sigma_from, sigma_to, eta=1.):
|
| 12 |
-
"""
|
| 13 |
-
Calculates the noise level (sigma_down) to step down to and the amount
|
| 14 |
-
of noise to add (sigma_up) when doing an ancestral sampling step.
|
| 15 |
-
"""
|
| 16 |
-
if not eta:
|
| 17 |
-
return sigma_to, 0.
|
| 18 |
-
sigma_up = min(
|
| 19 |
-
sigma_to,
|
| 20 |
-
eta * (
|
| 21 |
-
sigma_to**2 * # noqa
|
| 22 |
-
(sigma_from**2 - sigma_to**2) / sigma_from**2)**0.5)
|
| 23 |
-
sigma_down = (sigma_to**2 - sigma_up**2)**0.5
|
| 24 |
-
return sigma_down, sigma_up
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
def get_scalings(sigma):
|
| 28 |
-
c_out = -sigma
|
| 29 |
-
c_in = 1 / (sigma**2 + 1.**2)**0.5
|
| 30 |
-
return c_out, c_in
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
@torch.no_grad()
|
| 34 |
-
def sample_heun(noise,
|
| 35 |
-
model,
|
| 36 |
-
sigmas,
|
| 37 |
-
s_churn=0.,
|
| 38 |
-
s_tmin=0.,
|
| 39 |
-
s_tmax=float('inf'),
|
| 40 |
-
s_noise=1.,
|
| 41 |
-
show_progress=True):
|
| 42 |
-
"""
|
| 43 |
-
Implements Algorithm 2 (Heun steps) from Karras et al. (2022).
|
| 44 |
-
"""
|
| 45 |
-
x = noise * sigmas[0]
|
| 46 |
-
for i in trange(len(sigmas) - 1, disable=not show_progress):
|
| 47 |
-
gamma = 0.
|
| 48 |
-
if s_tmin <= sigmas[i] <= s_tmax and sigmas[i] < float('inf'):
|
| 49 |
-
gamma = min(s_churn / (len(sigmas) - 1), 2**0.5 - 1)
|
| 50 |
-
eps = torch.randn_like(x) * s_noise
|
| 51 |
-
sigma_hat = sigmas[i] * (gamma + 1)
|
| 52 |
-
if gamma > 0:
|
| 53 |
-
x = x + eps * (sigma_hat**2 - sigmas[i]**2)**0.5
|
| 54 |
-
if sigmas[i] == float('inf'):
|
| 55 |
-
# Euler method
|
| 56 |
-
denoised = model(noise, sigma_hat)
|
| 57 |
-
x = denoised + sigmas[i + 1] * (gamma + 1) * noise
|
| 58 |
-
else:
|
| 59 |
-
_, c_in = get_scalings(sigma_hat)
|
| 60 |
-
denoised = model(x * c_in, sigma_hat)
|
| 61 |
-
d = (x - denoised) / sigma_hat
|
| 62 |
-
dt = sigmas[i + 1] - sigma_hat
|
| 63 |
-
if sigmas[i + 1] == 0:
|
| 64 |
-
# Euler method
|
| 65 |
-
x = x + d * dt
|
| 66 |
-
else:
|
| 67 |
-
# Heun's method
|
| 68 |
-
x_2 = x + d * dt
|
| 69 |
-
_, c_in = get_scalings(sigmas[i + 1])
|
| 70 |
-
denoised_2 = model(x_2 * c_in, sigmas[i + 1])
|
| 71 |
-
d_2 = (x_2 - denoised_2) / sigmas[i + 1]
|
| 72 |
-
d_prime = (d + d_2) / 2
|
| 73 |
-
x = x + d_prime * dt
|
| 74 |
-
return x
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
class BatchedBrownianTree:
|
| 78 |
-
"""
|
| 79 |
-
A wrapper around torchsde.BrownianTree that enables batches of entropy.
|
| 80 |
-
"""
|
| 81 |
-
|
| 82 |
-
def __init__(self, x, t0, t1, seed=None, **kwargs):
|
| 83 |
-
t0, t1, self.sign = self.sort(t0, t1)
|
| 84 |
-
w0 = kwargs.get('w0', torch.zeros_like(x))
|
| 85 |
-
if seed is None:
|
| 86 |
-
seed = torch.randint(0, 2**63 - 1, []).item()
|
| 87 |
-
self.batched = True
|
| 88 |
-
try:
|
| 89 |
-
assert len(seed) == x.shape[0]
|
| 90 |
-
w0 = w0[0]
|
| 91 |
-
except TypeError:
|
| 92 |
-
seed = [seed]
|
| 93 |
-
self.batched = False
|
| 94 |
-
self.trees = [
|
| 95 |
-
torchsde.BrownianTree(t0, w0, t1, entropy=s, **kwargs)
|
| 96 |
-
for s in seed
|
| 97 |
-
]
|
| 98 |
-
|
| 99 |
-
@staticmethod
|
| 100 |
-
def sort(a, b):
|
| 101 |
-
return (a, b, 1) if a < b else (b, a, -1)
|
| 102 |
-
|
| 103 |
-
def __call__(self, t0, t1):
|
| 104 |
-
t0, t1, sign = self.sort(t0, t1)
|
| 105 |
-
w = torch.stack([tree(t0, t1) for tree in self.trees]) * (
|
| 106 |
-
self.sign * sign)
|
| 107 |
-
return w if self.batched else w[0]
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
class BrownianTreeNoiseSampler:
|
| 111 |
-
"""
|
| 112 |
-
A noise sampler backed by a torchsde.BrownianTree.
|
| 113 |
-
|
| 114 |
-
Args:
|
| 115 |
-
x (Tensor): The tensor whose shape, device and dtype to use to generate
|
| 116 |
-
random samples.
|
| 117 |
-
sigma_min (float): The low end of the valid interval.
|
| 118 |
-
sigma_max (float): The high end of the valid interval.
|
| 119 |
-
seed (int or List[int]): The random seed. If a list of seeds is
|
| 120 |
-
supplied instead of a single integer, then the noise sampler will
|
| 121 |
-
use one BrownianTree per batch item, each with its own seed.
|
| 122 |
-
transform (callable): A function that maps sigma to the sampler's
|
| 123 |
-
internal timestep.
|
| 124 |
-
"""
|
| 125 |
-
|
| 126 |
-
def __init__(self,
|
| 127 |
-
x,
|
| 128 |
-
sigma_min,
|
| 129 |
-
sigma_max,
|
| 130 |
-
seed=None,
|
| 131 |
-
transform=lambda x: x):
|
| 132 |
-
self.transform = transform
|
| 133 |
-
t0 = self.transform(torch.as_tensor(sigma_min))
|
| 134 |
-
t1 = self.transform(torch.as_tensor(sigma_max))
|
| 135 |
-
self.tree = BatchedBrownianTree(x, t0, t1, seed)
|
| 136 |
-
|
| 137 |
-
def __call__(self, sigma, sigma_next):
|
| 138 |
-
t0 = self.transform(torch.as_tensor(sigma))
|
| 139 |
-
t1 = self.transform(torch.as_tensor(sigma_next))
|
| 140 |
-
return self.tree(t0, t1) / (t1 - t0).abs().sqrt()
|
| 141 |
-
|
| 142 |
-
|
| 143 |
-
@torch.no_grad()
|
| 144 |
-
def sample_dpmpp_2m_sde(noise,
|
| 145 |
-
model,
|
| 146 |
-
sigmas,
|
| 147 |
-
eta=1.,
|
| 148 |
-
s_noise=1.,
|
| 149 |
-
solver_type='midpoint',
|
| 150 |
-
show_progress=True,
|
| 151 |
-
variant_info=None):
|
| 152 |
-
"""
|
| 153 |
-
DPM-Solver++ (2M) SDE.
|
| 154 |
-
"""
|
| 155 |
-
assert solver_type in {'heun', 'midpoint'}
|
| 156 |
-
|
| 157 |
-
x = noise * sigmas[0]
|
| 158 |
-
sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas[
|
| 159 |
-
sigmas < float('inf')].max()
|
| 160 |
-
noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max)
|
| 161 |
-
old_denoised = None
|
| 162 |
-
h_last = None
|
| 163 |
-
|
| 164 |
-
for i in trange(len(sigmas) - 1, disable=not show_progress):
|
| 165 |
-
logger.info(f'step: {i}')
|
| 166 |
-
if sigmas[i] == float('inf'):
|
| 167 |
-
# Euler method
|
| 168 |
-
denoised = model(noise, sigmas[i], variant_info=variant_info)
|
| 169 |
-
x = denoised + sigmas[i + 1] * noise
|
| 170 |
-
else:
|
| 171 |
-
_, c_in = get_scalings(sigmas[i])
|
| 172 |
-
denoised = model(x * c_in, sigmas[i], variant_info=variant_info)
|
| 173 |
-
if sigmas[i + 1] == 0:
|
| 174 |
-
# Denoising step
|
| 175 |
-
x = denoised
|
| 176 |
-
else:
|
| 177 |
-
# DPM-Solver++(2M) SDE
|
| 178 |
-
t, s = -sigmas[i].log(), -sigmas[i + 1].log()
|
| 179 |
-
h = s - t
|
| 180 |
-
eta_h = eta * h
|
| 181 |
-
|
| 182 |
-
x = sigmas[i + 1] / sigmas[i] * (-eta_h).exp() * x + \
|
| 183 |
-
(-h - eta_h).expm1().neg() * denoised
|
| 184 |
-
|
| 185 |
-
if old_denoised is not None:
|
| 186 |
-
r = h_last / h
|
| 187 |
-
if solver_type == 'heun':
|
| 188 |
-
x = x + ((-h - eta_h).expm1().neg() / (-h - eta_h) + 1) * \
|
| 189 |
-
(1 / r) * (denoised - old_denoised)
|
| 190 |
-
elif solver_type == 'midpoint':
|
| 191 |
-
x = x + 0.5 * (-h - eta_h).expm1().neg() * \
|
| 192 |
-
(1 / r) * (denoised - old_denoised)
|
| 193 |
-
|
| 194 |
-
x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * sigmas[
|
| 195 |
-
i + 1] * (-2 * eta_h).expm1().neg().sqrt() * s_noise
|
| 196 |
-
|
| 197 |
-
old_denoised = denoised
|
| 198 |
-
h_last = h
|
| 199 |
-
|
| 200 |
-
if variant_info is not None and variant_info.get('type') == 'variant1':
|
| 201 |
-
x_long, x_short = x.chunk(2, dim=0)
|
| 202 |
-
x = x_long * (1-variant_info['alpha']) + x_short * variant_info['alpha']
|
| 203 |
-
|
| 204 |
-
return x
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
video_to_video/modules/__init__.py
DELETED
|
@@ -1,3 +0,0 @@
|
|
| 1 |
-
from .embedder import *
|
| 2 |
-
from .unet_v2v import *
|
| 3 |
-
# from .unet_v2v_deform import *
|
|
|
|
|
|
|
|
|
|
|
|
video_to_video/modules/__pycache__/__init__.cpython-39.pyc
DELETED
|
Binary file (206 Bytes)
|
|
|
video_to_video/modules/__pycache__/embedder.cpython-39.pyc
DELETED
|
Binary file (2.58 kB)
|
|
|
video_to_video/modules/__pycache__/t5.cpython-39.pyc
DELETED
|
Binary file (7.07 kB)
|
|
|
video_to_video/modules/__pycache__/unet_v2v.cpython-39.pyc
DELETED
|
Binary file (47.6 kB)
|
|
|
video_to_video/modules/__pycache__/unet_v2v_LocalConv.cpython-39.pyc
DELETED
|
Binary file (47.8 kB)
|
|
|
video_to_video/modules/__pycache__/unet_v2v_deform.cpython-39.pyc
DELETED
|
Binary file (48.2 kB)
|
|
|
video_to_video/modules/embedder.py
DELETED
|
@@ -1,75 +0,0 @@
|
|
| 1 |
-
# Copyright (c) Alibaba, Inc. and its affiliates.
|
| 2 |
-
|
| 3 |
-
import os
|
| 4 |
-
|
| 5 |
-
import numpy as np
|
| 6 |
-
import open_clip
|
| 7 |
-
import torch
|
| 8 |
-
import torch.nn as nn
|
| 9 |
-
import torchvision.transforms as T
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
class FrozenOpenCLIPEmbedder(nn.Module):
|
| 13 |
-
"""
|
| 14 |
-
Uses the OpenCLIP transformer encoder for text
|
| 15 |
-
"""
|
| 16 |
-
LAYERS = ['last', 'penultimate']
|
| 17 |
-
|
| 18 |
-
def __init__(self,
|
| 19 |
-
pretrained='laion2b_s32b_b79k',
|
| 20 |
-
arch='ViT-H-14',
|
| 21 |
-
device='cuda',
|
| 22 |
-
max_length=77,
|
| 23 |
-
freeze=True,
|
| 24 |
-
layer='penultimate'):
|
| 25 |
-
super().__init__()
|
| 26 |
-
assert layer in self.LAYERS
|
| 27 |
-
model, _, _ = open_clip.create_model_and_transforms(arch, device=torch.device('cpu'), pretrained=pretrained)
|
| 28 |
-
|
| 29 |
-
del model.visual
|
| 30 |
-
self.model = model
|
| 31 |
-
self.device = device
|
| 32 |
-
self.max_length = max_length
|
| 33 |
-
|
| 34 |
-
if freeze:
|
| 35 |
-
self.freeze()
|
| 36 |
-
self.layer = layer
|
| 37 |
-
if self.layer == 'last':
|
| 38 |
-
self.layer_idx = 0
|
| 39 |
-
elif self.layer == 'penultimate':
|
| 40 |
-
self.layer_idx = 1
|
| 41 |
-
else:
|
| 42 |
-
raise NotImplementedError()
|
| 43 |
-
|
| 44 |
-
def freeze(self):
|
| 45 |
-
self.model = self.model.eval()
|
| 46 |
-
for param in self.parameters():
|
| 47 |
-
param.requires_grad = False
|
| 48 |
-
|
| 49 |
-
def forward(self, text):
|
| 50 |
-
tokens = open_clip.tokenize(text)
|
| 51 |
-
z = self.encode_with_transformer(tokens.to(self.device))
|
| 52 |
-
return z
|
| 53 |
-
|
| 54 |
-
def encode_with_transformer(self, text):
|
| 55 |
-
x = self.model.token_embedding(text)
|
| 56 |
-
x = x + self.model.positional_embedding
|
| 57 |
-
x = x.permute(1, 0, 2)
|
| 58 |
-
x = self.text_transformer_forward(x, attn_mask=self.model.attn_mask)
|
| 59 |
-
x = x.permute(1, 0, 2)
|
| 60 |
-
x = self.model.ln_final(x)
|
| 61 |
-
return x
|
| 62 |
-
|
| 63 |
-
def text_transformer_forward(self, x: torch.Tensor, attn_mask=None):
|
| 64 |
-
for i, r in enumerate(self.model.transformer.resblocks):
|
| 65 |
-
if i == len(self.model.transformer.resblocks) - self.layer_idx:
|
| 66 |
-
break
|
| 67 |
-
if self.model.transformer.grad_checkpointing and not torch.jit.is_scripting(
|
| 68 |
-
):
|
| 69 |
-
x = checkpoint(r, x, attn_mask)
|
| 70 |
-
else:
|
| 71 |
-
x = r(x, attn_mask=attn_mask)
|
| 72 |
-
return x
|
| 73 |
-
|
| 74 |
-
def encode(self, text):
|
| 75 |
-
return self(text)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
video_to_video/modules/t5.py
DELETED
|
@@ -1,335 +0,0 @@
|
|
| 1 |
-
# Adapted from PixArt
|
| 2 |
-
#
|
| 3 |
-
# Copyright (C) 2023 PixArt-alpha/PixArt-alpha
|
| 4 |
-
#
|
| 5 |
-
# This program is free software: you can redistribute it and/or modify
|
| 6 |
-
# it under the terms of the GNU Affero General Public License as published
|
| 7 |
-
# by the Free Software Foundation, either version 3 of the License, or
|
| 8 |
-
# (at your option) any later version.
|
| 9 |
-
#
|
| 10 |
-
# This program is distributed in the hope that it will be useful,
|
| 11 |
-
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
| 12 |
-
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
| 13 |
-
# GNU Affero General Public License for more details.
|
| 14 |
-
#
|
| 15 |
-
#
|
| 16 |
-
# This source code is licensed under the license found in the
|
| 17 |
-
# LICENSE file in the root directory of this source tree.
|
| 18 |
-
# --------------------------------------------------------
|
| 19 |
-
# References:
|
| 20 |
-
# PixArt: https://github.com/PixArt-alpha/PixArt-alpha
|
| 21 |
-
# T5: https://github.com/google-research/text-to-text-transfer-transformer
|
| 22 |
-
# --------------------------------------------------------
|
| 23 |
-
|
| 24 |
-
import html
|
| 25 |
-
import re
|
| 26 |
-
|
| 27 |
-
import ftfy
|
| 28 |
-
import torch
|
| 29 |
-
from transformers import AutoTokenizer, T5EncoderModel
|
| 30 |
-
|
| 31 |
-
# from opensora.registry import MODELS
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
class T5Embedder:
|
| 35 |
-
def __init__(
|
| 36 |
-
self,
|
| 37 |
-
device,
|
| 38 |
-
from_pretrained=None,
|
| 39 |
-
*,
|
| 40 |
-
cache_dir=None,
|
| 41 |
-
hf_token=None,
|
| 42 |
-
use_text_preprocessing=True,
|
| 43 |
-
t5_model_kwargs=None,
|
| 44 |
-
torch_dtype=None,
|
| 45 |
-
use_offload_folder=None,
|
| 46 |
-
model_max_length=120,
|
| 47 |
-
local_files_only=False,
|
| 48 |
-
):
|
| 49 |
-
self.device = torch.device(device)
|
| 50 |
-
self.torch_dtype = torch_dtype or torch.bfloat16
|
| 51 |
-
self.cache_dir = cache_dir
|
| 52 |
-
|
| 53 |
-
if t5_model_kwargs is None:
|
| 54 |
-
t5_model_kwargs = {
|
| 55 |
-
"low_cpu_mem_usage": True,
|
| 56 |
-
"torch_dtype": self.torch_dtype,
|
| 57 |
-
}
|
| 58 |
-
|
| 59 |
-
if use_offload_folder is not None:
|
| 60 |
-
t5_model_kwargs["offload_folder"] = use_offload_folder
|
| 61 |
-
t5_model_kwargs["device_map"] = {
|
| 62 |
-
"shared": self.device,
|
| 63 |
-
"encoder.embed_tokens": self.device,
|
| 64 |
-
"encoder.block.0": self.device,
|
| 65 |
-
"encoder.block.1": self.device,
|
| 66 |
-
"encoder.block.2": self.device,
|
| 67 |
-
"encoder.block.3": self.device,
|
| 68 |
-
"encoder.block.4": self.device,
|
| 69 |
-
"encoder.block.5": self.device,
|
| 70 |
-
"encoder.block.6": self.device,
|
| 71 |
-
"encoder.block.7": self.device,
|
| 72 |
-
"encoder.block.8": self.device,
|
| 73 |
-
"encoder.block.9": self.device,
|
| 74 |
-
"encoder.block.10": self.device,
|
| 75 |
-
"encoder.block.11": self.device,
|
| 76 |
-
"encoder.block.12": "disk",
|
| 77 |
-
"encoder.block.13": "disk",
|
| 78 |
-
"encoder.block.14": "disk",
|
| 79 |
-
"encoder.block.15": "disk",
|
| 80 |
-
"encoder.block.16": "disk",
|
| 81 |
-
"encoder.block.17": "disk",
|
| 82 |
-
"encoder.block.18": "disk",
|
| 83 |
-
"encoder.block.19": "disk",
|
| 84 |
-
"encoder.block.20": "disk",
|
| 85 |
-
"encoder.block.21": "disk",
|
| 86 |
-
"encoder.block.22": "disk",
|
| 87 |
-
"encoder.block.23": "disk",
|
| 88 |
-
"encoder.final_layer_norm": "disk",
|
| 89 |
-
"encoder.dropout": "disk",
|
| 90 |
-
}
|
| 91 |
-
else:
|
| 92 |
-
t5_model_kwargs["device_map"] = {
|
| 93 |
-
"shared": self.device,
|
| 94 |
-
"encoder": self.device,
|
| 95 |
-
}
|
| 96 |
-
|
| 97 |
-
self.use_text_preprocessing = use_text_preprocessing
|
| 98 |
-
self.hf_token = hf_token
|
| 99 |
-
|
| 100 |
-
self.tokenizer = AutoTokenizer.from_pretrained(
|
| 101 |
-
from_pretrained,
|
| 102 |
-
cache_dir=cache_dir,
|
| 103 |
-
local_files_only=local_files_only,
|
| 104 |
-
)
|
| 105 |
-
self.model = T5EncoderModel.from_pretrained(
|
| 106 |
-
from_pretrained,
|
| 107 |
-
cache_dir=cache_dir,
|
| 108 |
-
local_files_only=local_files_only,
|
| 109 |
-
**t5_model_kwargs,
|
| 110 |
-
).eval()
|
| 111 |
-
self.model_max_length = model_max_length
|
| 112 |
-
|
| 113 |
-
def get_text_embeddings(self, texts):
|
| 114 |
-
text_tokens_and_mask = self.tokenizer(
|
| 115 |
-
texts,
|
| 116 |
-
max_length=self.model_max_length,
|
| 117 |
-
padding="max_length",
|
| 118 |
-
truncation=True,
|
| 119 |
-
return_attention_mask=True,
|
| 120 |
-
add_special_tokens=True,
|
| 121 |
-
return_tensors="pt",
|
| 122 |
-
)
|
| 123 |
-
|
| 124 |
-
input_ids = text_tokens_and_mask["input_ids"].to(self.device)
|
| 125 |
-
attention_mask = text_tokens_and_mask["attention_mask"].to(self.device)
|
| 126 |
-
with torch.no_grad():
|
| 127 |
-
text_encoder_embs = self.model(
|
| 128 |
-
input_ids=input_ids,
|
| 129 |
-
attention_mask=attention_mask,
|
| 130 |
-
)["last_hidden_state"].detach()
|
| 131 |
-
return text_encoder_embs, attention_mask
|
| 132 |
-
|
| 133 |
-
|
| 134 |
-
# @MODELS.register_module("t5")
|
| 135 |
-
class T5Encoder:
|
| 136 |
-
def __init__(
|
| 137 |
-
self,
|
| 138 |
-
from_pretrained=None,
|
| 139 |
-
model_max_length=120,
|
| 140 |
-
device="cuda",
|
| 141 |
-
dtype=torch.float,
|
| 142 |
-
cache_dir=None,
|
| 143 |
-
shardformer=False,
|
| 144 |
-
local_files_only=False,
|
| 145 |
-
):
|
| 146 |
-
assert from_pretrained is not None, "Please specify the path to the T5 model"
|
| 147 |
-
|
| 148 |
-
self.t5 = T5Embedder(
|
| 149 |
-
device=device,
|
| 150 |
-
torch_dtype=dtype,
|
| 151 |
-
from_pretrained=from_pretrained,
|
| 152 |
-
cache_dir=cache_dir,
|
| 153 |
-
model_max_length=model_max_length,
|
| 154 |
-
local_files_only=local_files_only,
|
| 155 |
-
)
|
| 156 |
-
self.t5.model.to(dtype=dtype)
|
| 157 |
-
self.y_embedder = None
|
| 158 |
-
|
| 159 |
-
self.model_max_length = model_max_length
|
| 160 |
-
self.output_dim = self.t5.model.config.d_model
|
| 161 |
-
self.dtype = dtype
|
| 162 |
-
|
| 163 |
-
if shardformer:
|
| 164 |
-
self.shardformer_t5()
|
| 165 |
-
|
| 166 |
-
def shardformer_t5(self):
|
| 167 |
-
from colossalai.shardformer import ShardConfig, ShardFormer
|
| 168 |
-
|
| 169 |
-
from opensora.acceleration.shardformer.policy.t5_encoder import T5EncoderPolicy
|
| 170 |
-
from opensora.utils.misc import requires_grad
|
| 171 |
-
|
| 172 |
-
shard_config = ShardConfig(
|
| 173 |
-
tensor_parallel_process_group=None,
|
| 174 |
-
pipeline_stage_manager=None,
|
| 175 |
-
enable_tensor_parallelism=False,
|
| 176 |
-
enable_fused_normalization=False,
|
| 177 |
-
enable_flash_attention=False,
|
| 178 |
-
enable_jit_fused=True,
|
| 179 |
-
enable_sequence_parallelism=False,
|
| 180 |
-
enable_sequence_overlap=False,
|
| 181 |
-
)
|
| 182 |
-
shard_former = ShardFormer(shard_config=shard_config)
|
| 183 |
-
optim_model, _ = shard_former.optimize(self.t5.model, policy=T5EncoderPolicy())
|
| 184 |
-
self.t5.model = optim_model.to(self.dtype)
|
| 185 |
-
|
| 186 |
-
# ensure the weights are frozen
|
| 187 |
-
requires_grad(self.t5.model, False)
|
| 188 |
-
|
| 189 |
-
def encode(self, text):
|
| 190 |
-
caption_embs, emb_masks = self.t5.get_text_embeddings(text)
|
| 191 |
-
caption_embs = caption_embs[:, None]
|
| 192 |
-
return dict(y=caption_embs, mask=emb_masks)
|
| 193 |
-
|
| 194 |
-
def null(self, n):
|
| 195 |
-
null_y = self.y_embedder.y_embedding[None].repeat(n, 1, 1)[:, None]
|
| 196 |
-
return null_y
|
| 197 |
-
|
| 198 |
-
|
| 199 |
-
def basic_clean(text):
|
| 200 |
-
text = ftfy.fix_text(text)
|
| 201 |
-
text = html.unescape(html.unescape(text))
|
| 202 |
-
return text.strip()
|
| 203 |
-
|
| 204 |
-
|
| 205 |
-
BAD_PUNCT_REGEX = re.compile(
|
| 206 |
-
r"[" + "#®•©™&@·º½¾¿¡§~" + "\)" + "\(" + "\]" + "\[" + "\}" + "\{" + "\|" + "\\" + "\/" + "\*" + r"]{1,}"
|
| 207 |
-
) # noqa
|
| 208 |
-
|
| 209 |
-
|
| 210 |
-
def clean_caption(caption):
|
| 211 |
-
import urllib.parse as ul
|
| 212 |
-
|
| 213 |
-
from bs4 import BeautifulSoup
|
| 214 |
-
|
| 215 |
-
caption = str(caption)
|
| 216 |
-
caption = ul.unquote_plus(caption)
|
| 217 |
-
caption = caption.strip().lower()
|
| 218 |
-
caption = re.sub("<person>", "person", caption)
|
| 219 |
-
# urls:
|
| 220 |
-
caption = re.sub(
|
| 221 |
-
r"\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa
|
| 222 |
-
"",
|
| 223 |
-
caption,
|
| 224 |
-
) # regex for urls
|
| 225 |
-
caption = re.sub(
|
| 226 |
-
r"\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa
|
| 227 |
-
"",
|
| 228 |
-
caption,
|
| 229 |
-
) # regex for urls
|
| 230 |
-
# html:
|
| 231 |
-
caption = BeautifulSoup(caption, features="html.parser").text
|
| 232 |
-
|
| 233 |
-
# @<nickname>
|
| 234 |
-
caption = re.sub(r"@[\w\d]+\b", "", caption)
|
| 235 |
-
|
| 236 |
-
# 31C0—31EF CJK Strokes
|
| 237 |
-
# 31F0—31FF Katakana Phonetic Extensions
|
| 238 |
-
# 3200—32FF Enclosed CJK Letters and Months
|
| 239 |
-
# 3300—33FF CJK Compatibility
|
| 240 |
-
# 3400—4DBF CJK Unified Ideographs Extension A
|
| 241 |
-
# 4DC0—4DFF Yijing Hexagram Symbols
|
| 242 |
-
# 4E00—9FFF CJK Unified Ideographs
|
| 243 |
-
caption = re.sub(r"[\u31c0-\u31ef]+", "", caption)
|
| 244 |
-
caption = re.sub(r"[\u31f0-\u31ff]+", "", caption)
|
| 245 |
-
caption = re.sub(r"[\u3200-\u32ff]+", "", caption)
|
| 246 |
-
caption = re.sub(r"[\u3300-\u33ff]+", "", caption)
|
| 247 |
-
caption = re.sub(r"[\u3400-\u4dbf]+", "", caption)
|
| 248 |
-
caption = re.sub(r"[\u4dc0-\u4dff]+", "", caption)
|
| 249 |
-
caption = re.sub(r"[\u4e00-\u9fff]+", "", caption)
|
| 250 |
-
#######################################################
|
| 251 |
-
|
| 252 |
-
# все виды тире / all types of dash --> "-"
|
| 253 |
-
caption = re.sub(
|
| 254 |
-
r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+", # noqa
|
| 255 |
-
"-",
|
| 256 |
-
caption,
|
| 257 |
-
)
|
| 258 |
-
|
| 259 |
-
# кавычки к одному стандарту
|
| 260 |
-
caption = re.sub(r"[`´«»“”¨]", '"', caption)
|
| 261 |
-
caption = re.sub(r"[‘’]", "'", caption)
|
| 262 |
-
|
| 263 |
-
# "
|
| 264 |
-
caption = re.sub(r""?", "", caption)
|
| 265 |
-
# &
|
| 266 |
-
caption = re.sub(r"&", "", caption)
|
| 267 |
-
|
| 268 |
-
# ip adresses:
|
| 269 |
-
caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption)
|
| 270 |
-
|
| 271 |
-
# article ids:
|
| 272 |
-
caption = re.sub(r"\d:\d\d\s+$", "", caption)
|
| 273 |
-
|
| 274 |
-
# \n
|
| 275 |
-
caption = re.sub(r"\\n", " ", caption)
|
| 276 |
-
|
| 277 |
-
# "#123"
|
| 278 |
-
caption = re.sub(r"#\d{1,3}\b", "", caption)
|
| 279 |
-
# "#12345.."
|
| 280 |
-
caption = re.sub(r"#\d{5,}\b", "", caption)
|
| 281 |
-
# "123456.."
|
| 282 |
-
caption = re.sub(r"\b\d{6,}\b", "", caption)
|
| 283 |
-
# filenames:
|
| 284 |
-
caption = re.sub(r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", caption)
|
| 285 |
-
|
| 286 |
-
#
|
| 287 |
-
caption = re.sub(r"[\"\']{2,}", r'"', caption) # """AUSVERKAUFT"""
|
| 288 |
-
caption = re.sub(r"[\.]{2,}", r" ", caption) # """AUSVERKAUFT"""
|
| 289 |
-
|
| 290 |
-
caption = re.sub(BAD_PUNCT_REGEX, r" ", caption) # ***AUSVERKAUFT***, #AUSVERKAUFT
|
| 291 |
-
caption = re.sub(r"\s+\.\s+", r" ", caption) # " . "
|
| 292 |
-
|
| 293 |
-
# this-is-my-cute-cat / this_is_my_cute_cat
|
| 294 |
-
regex2 = re.compile(r"(?:\-|\_)")
|
| 295 |
-
if len(re.findall(regex2, caption)) > 3:
|
| 296 |
-
caption = re.sub(regex2, " ", caption)
|
| 297 |
-
|
| 298 |
-
caption = basic_clean(caption)
|
| 299 |
-
|
| 300 |
-
caption = re.sub(r"\b[a-zA-Z]{1,3}\d{3,15}\b", "", caption) # jc6640
|
| 301 |
-
caption = re.sub(r"\b[a-zA-Z]+\d+[a-zA-Z]+\b", "", caption) # jc6640vc
|
| 302 |
-
caption = re.sub(r"\b\d+[a-zA-Z]+\d+\b", "", caption) # 6640vc231
|
| 303 |
-
|
| 304 |
-
caption = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", caption)
|
| 305 |
-
caption = re.sub(r"(free\s)?download(\sfree)?", "", caption)
|
| 306 |
-
caption = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", caption)
|
| 307 |
-
caption = re.sub(r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "", caption)
|
| 308 |
-
caption = re.sub(r"\bpage\s+\d+\b", "", caption)
|
| 309 |
-
|
| 310 |
-
caption = re.sub(r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ", caption) # j2d1a2a...
|
| 311 |
-
|
| 312 |
-
caption = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", caption)
|
| 313 |
-
|
| 314 |
-
caption = re.sub(r"\b\s+\:\s+", r": ", caption)
|
| 315 |
-
caption = re.sub(r"(\D[,\./])\b", r"\1 ", caption)
|
| 316 |
-
caption = re.sub(r"\s+", " ", caption)
|
| 317 |
-
|
| 318 |
-
caption.strip()
|
| 319 |
-
|
| 320 |
-
caption = re.sub(r"^[\"\']([\w\W]+)[\"\']$", r"\1", caption)
|
| 321 |
-
caption = re.sub(r"^[\'\_,\-\:;]", r"", caption)
|
| 322 |
-
caption = re.sub(r"[\'\_,\-\:\-\+]$", r"", caption)
|
| 323 |
-
caption = re.sub(r"^\.\S+$", "", caption)
|
| 324 |
-
|
| 325 |
-
return caption.strip()
|
| 326 |
-
|
| 327 |
-
|
| 328 |
-
def text_preprocessing(text, use_text_preprocessing: bool = True):
|
| 329 |
-
if use_text_preprocessing:
|
| 330 |
-
# The exact text cleaning as was in the training stage:
|
| 331 |
-
text = clean_caption(text)
|
| 332 |
-
text = clean_caption(text)
|
| 333 |
-
return text
|
| 334 |
-
else:
|
| 335 |
-
return text.lower().strip()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
video_to_video/modules/unet_v2v.py
DELETED
|
@@ -1,2332 +0,0 @@
|
|
| 1 |
-
# Copyright (c) Alibaba, Inc. and its affiliates.
|
| 2 |
-
|
| 3 |
-
import math
|
| 4 |
-
import os
|
| 5 |
-
from abc import abstractmethod
|
| 6 |
-
|
| 7 |
-
import torch
|
| 8 |
-
import torch.nn as nn
|
| 9 |
-
import torch.nn.functional as F
|
| 10 |
-
import xformers
|
| 11 |
-
import xformers.ops
|
| 12 |
-
from einops import rearrange
|
| 13 |
-
from fairscale.nn.checkpoint import checkpoint_wrapper
|
| 14 |
-
from timm.models.vision_transformer import Mlp
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
USE_TEMPORAL_TRANSFORMER = True
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
class CaptionEmbedder(nn.Module):
|
| 21 |
-
"""
|
| 22 |
-
Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance.
|
| 23 |
-
"""
|
| 24 |
-
|
| 25 |
-
def __init__(self, in_channels, hidden_size, uncond_prob, act_layer=nn.GELU(approximate="tanh"), token_num=120):
|
| 26 |
-
super().__init__()
|
| 27 |
-
self.y_proj = Mlp(
|
| 28 |
-
in_features=in_channels, hidden_features=hidden_size, out_features=hidden_size, act_layer=act_layer, drop=0
|
| 29 |
-
)
|
| 30 |
-
self.register_buffer("y_embedding", nn.Parameter(torch.randn(token_num, in_channels) / in_channels**0.5))
|
| 31 |
-
self.uncond_prob = uncond_prob
|
| 32 |
-
|
| 33 |
-
def token_drop(self, caption, force_drop_ids=None):
|
| 34 |
-
"""
|
| 35 |
-
Drops labels to enable classifier-free guidance.
|
| 36 |
-
"""
|
| 37 |
-
if force_drop_ids is None:
|
| 38 |
-
drop_ids = torch.rand(caption.shape[0]).cuda() < self.uncond_prob
|
| 39 |
-
else:
|
| 40 |
-
drop_ids = force_drop_ids == 1
|
| 41 |
-
caption = torch.where(drop_ids[:, None, None, None], self.y_embedding, caption)
|
| 42 |
-
return caption
|
| 43 |
-
|
| 44 |
-
def forward(self, caption, train, force_drop_ids=None):
|
| 45 |
-
if train:
|
| 46 |
-
assert caption.shape[2:] == self.y_embedding.shape
|
| 47 |
-
use_dropout = self.uncond_prob > 0
|
| 48 |
-
if (train and use_dropout) or (force_drop_ids is not None):
|
| 49 |
-
caption = self.token_drop(caption, force_drop_ids)
|
| 50 |
-
caption = self.y_proj(caption)
|
| 51 |
-
return caption
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
class DropPath(nn.Module):
|
| 55 |
-
r"""DropPath but without rescaling and supports optional all-zero and/or all-keep.
|
| 56 |
-
"""
|
| 57 |
-
|
| 58 |
-
def __init__(self, p):
|
| 59 |
-
super(DropPath, self).__init__()
|
| 60 |
-
self.p = p
|
| 61 |
-
|
| 62 |
-
def forward(self, *args, zero=None, keep=None):
|
| 63 |
-
if not self.training:
|
| 64 |
-
return args[0] if len(args) == 1 else args
|
| 65 |
-
|
| 66 |
-
# params
|
| 67 |
-
x = args[0]
|
| 68 |
-
b = x.size(0)
|
| 69 |
-
n = (torch.rand(b) < self.p).sum()
|
| 70 |
-
|
| 71 |
-
# non-zero and non-keep mask
|
| 72 |
-
mask = x.new_ones(b, dtype=torch.bool)
|
| 73 |
-
if keep is not None:
|
| 74 |
-
mask[keep] = False
|
| 75 |
-
if zero is not None:
|
| 76 |
-
mask[zero] = False
|
| 77 |
-
|
| 78 |
-
# drop-path index
|
| 79 |
-
index = torch.where(mask)[0]
|
| 80 |
-
index = index[torch.randperm(len(index))[:n]]
|
| 81 |
-
if zero is not None:
|
| 82 |
-
index = torch.cat([index, torch.where(zero)[0]], dim=0)
|
| 83 |
-
|
| 84 |
-
# drop-path multiplier
|
| 85 |
-
multiplier = x.new_ones(b)
|
| 86 |
-
multiplier[index] = 0.0
|
| 87 |
-
output = tuple(u * self.broadcast(multiplier, u) for u in args)
|
| 88 |
-
return output[0] if len(args) == 1 else output
|
| 89 |
-
|
| 90 |
-
def broadcast(self, src, dst):
|
| 91 |
-
assert src.size(0) == dst.size(0)
|
| 92 |
-
shape = (dst.size(0), ) + (1, ) * (dst.ndim - 1)
|
| 93 |
-
return src.view(shape)
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
def sinusoidal_embedding(timesteps, dim):
|
| 97 |
-
# check input
|
| 98 |
-
half = dim // 2
|
| 99 |
-
timesteps = timesteps.float()
|
| 100 |
-
|
| 101 |
-
# compute sinusoidal embedding
|
| 102 |
-
sinusoid = torch.outer(
|
| 103 |
-
timesteps, torch.pow(10000,
|
| 104 |
-
-torch.arange(half).to(timesteps).div(half)))
|
| 105 |
-
x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1)
|
| 106 |
-
if dim % 2 != 0:
|
| 107 |
-
x = torch.cat([x, torch.zeros_like(x[:, :1])], dim=1)
|
| 108 |
-
return x
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
def exists(x):
|
| 112 |
-
return x is not None
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
def default(val, d):
|
| 116 |
-
if exists(val):
|
| 117 |
-
return val
|
| 118 |
-
return d() if callable(d) else d
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
def prob_mask_like(shape, prob, device):
|
| 122 |
-
if prob == 1:
|
| 123 |
-
return torch.ones(shape, device=device, dtype=torch.bool)
|
| 124 |
-
elif prob == 0:
|
| 125 |
-
return torch.zeros(shape, device=device, dtype=torch.bool)
|
| 126 |
-
else:
|
| 127 |
-
mask = torch.zeros(shape, device=device).float().uniform_(0, 1) < prob
|
| 128 |
-
# aviod mask all, which will cause find_unused_parameters error
|
| 129 |
-
if mask.all():
|
| 130 |
-
mask[0] = False
|
| 131 |
-
return mask
|
| 132 |
-
|
| 133 |
-
|
| 134 |
-
class MemoryEfficientCrossAttention(nn.Module):
|
| 135 |
-
|
| 136 |
-
def __init__(self,
|
| 137 |
-
query_dim,
|
| 138 |
-
context_dim=None,
|
| 139 |
-
heads=8,
|
| 140 |
-
dim_head=64,
|
| 141 |
-
max_bs=16384,
|
| 142 |
-
dropout=0.0):
|
| 143 |
-
super().__init__()
|
| 144 |
-
inner_dim = dim_head * heads
|
| 145 |
-
context_dim = default(context_dim, query_dim)
|
| 146 |
-
|
| 147 |
-
self.max_bs = max_bs
|
| 148 |
-
self.heads = heads
|
| 149 |
-
self.dim_head = dim_head
|
| 150 |
-
|
| 151 |
-
self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
|
| 152 |
-
self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
|
| 153 |
-
self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
|
| 154 |
-
self.to_out = nn.Sequential(
|
| 155 |
-
nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))
|
| 156 |
-
self.attention_op: Optional[Any] = None
|
| 157 |
-
|
| 158 |
-
def forward(self, x, context=None, mask=None):
|
| 159 |
-
q = self.to_q(x)
|
| 160 |
-
context = default(context, x)
|
| 161 |
-
k = self.to_k(context)
|
| 162 |
-
v = self.to_v(context)
|
| 163 |
-
|
| 164 |
-
b, _, _ = q.shape
|
| 165 |
-
q, k, v = map(
|
| 166 |
-
lambda t: t.unsqueeze(3).reshape(b, t.shape[
|
| 167 |
-
1], self.heads, self.dim_head).permute(0, 2, 1, 3).reshape(
|
| 168 |
-
b * self.heads, t.shape[1], self.dim_head).contiguous(),
|
| 169 |
-
(q, k, v),
|
| 170 |
-
)
|
| 171 |
-
|
| 172 |
-
# actually compute the attention, what we cannot get enough of.
|
| 173 |
-
if q.shape[0] > self.max_bs:
|
| 174 |
-
q_list = torch.chunk(q, q.shape[0] // self.max_bs, dim=0)
|
| 175 |
-
k_list = torch.chunk(k, k.shape[0] // self.max_bs, dim=0)
|
| 176 |
-
v_list = torch.chunk(v, v.shape[0] // self.max_bs, dim=0)
|
| 177 |
-
out_list = []
|
| 178 |
-
for q_1, k_1, v_1 in zip(q_list, k_list, v_list):
|
| 179 |
-
out = xformers.ops.memory_efficient_attention(
|
| 180 |
-
q_1, k_1, v_1, attn_bias=None, op=self.attention_op)
|
| 181 |
-
out_list.append(out)
|
| 182 |
-
out = torch.cat(out_list, dim=0)
|
| 183 |
-
else:
|
| 184 |
-
out = xformers.ops.memory_efficient_attention(
|
| 185 |
-
q, k, v, attn_bias=None, op=self.attention_op)
|
| 186 |
-
|
| 187 |
-
if exists(mask):
|
| 188 |
-
raise NotImplementedError
|
| 189 |
-
out = (
|
| 190 |
-
out.unsqueeze(0).reshape(
|
| 191 |
-
b, self.heads, out.shape[1],
|
| 192 |
-
self.dim_head).permute(0, 2, 1,
|
| 193 |
-
3).reshape(b, out.shape[1],
|
| 194 |
-
self.heads * self.dim_head))
|
| 195 |
-
return self.to_out(out)
|
| 196 |
-
|
| 197 |
-
|
| 198 |
-
class RelativePositionBias(nn.Module):
|
| 199 |
-
|
| 200 |
-
def __init__(self, heads=8, num_buckets=32, max_distance=128):
|
| 201 |
-
super().__init__()
|
| 202 |
-
self.num_buckets = num_buckets
|
| 203 |
-
self.max_distance = max_distance
|
| 204 |
-
self.relative_attention_bias = nn.Embedding(num_buckets, heads)
|
| 205 |
-
|
| 206 |
-
@staticmethod
|
| 207 |
-
def _relative_position_bucket(relative_position,
|
| 208 |
-
num_buckets=32,
|
| 209 |
-
max_distance=128):
|
| 210 |
-
ret = 0
|
| 211 |
-
n = -relative_position
|
| 212 |
-
|
| 213 |
-
num_buckets //= 2
|
| 214 |
-
ret += (n < 0).long() * num_buckets
|
| 215 |
-
n = torch.abs(n)
|
| 216 |
-
|
| 217 |
-
max_exact = num_buckets // 2
|
| 218 |
-
is_small = n < max_exact
|
| 219 |
-
|
| 220 |
-
val_if_large = max_exact + (
|
| 221 |
-
torch.log(n.float() / max_exact)
|
| 222 |
-
/ math.log(max_distance / max_exact) * # noqa
|
| 223 |
-
(num_buckets - max_exact)).long()
|
| 224 |
-
val_if_large = torch.min(
|
| 225 |
-
val_if_large, torch.full_like(val_if_large, num_buckets - 1))
|
| 226 |
-
|
| 227 |
-
ret += torch.where(is_small, n, val_if_large)
|
| 228 |
-
return ret
|
| 229 |
-
|
| 230 |
-
def forward(self, n, device):
|
| 231 |
-
q_pos = torch.arange(n, dtype=torch.long, device=device)
|
| 232 |
-
k_pos = torch.arange(n, dtype=torch.long, device=device)
|
| 233 |
-
rel_pos = rearrange(k_pos, 'j -> 1 j') - rearrange(q_pos, 'i -> i 1')
|
| 234 |
-
rp_bucket = self._relative_position_bucket(
|
| 235 |
-
rel_pos,
|
| 236 |
-
num_buckets=self.num_buckets,
|
| 237 |
-
max_distance=self.max_distance)
|
| 238 |
-
values = self.relative_attention_bias(rp_bucket)
|
| 239 |
-
return rearrange(values, 'i j h -> h i j')
|
| 240 |
-
|
| 241 |
-
|
| 242 |
-
class SpatialTransformer(nn.Module):
|
| 243 |
-
"""
|
| 244 |
-
Transformer block for image-like data.
|
| 245 |
-
First, project the input (aka embedding)
|
| 246 |
-
and reshape to b, t, d.
|
| 247 |
-
Then apply standard transformer action.
|
| 248 |
-
Finally, reshape to image
|
| 249 |
-
NEW: use_linear for more efficiency instead of the 1x1 convs
|
| 250 |
-
"""
|
| 251 |
-
|
| 252 |
-
def __init__(self,
|
| 253 |
-
in_channels,
|
| 254 |
-
n_heads,
|
| 255 |
-
d_head,
|
| 256 |
-
depth=1,
|
| 257 |
-
dropout=0.,
|
| 258 |
-
context_dim=None,
|
| 259 |
-
disable_self_attn=False,
|
| 260 |
-
use_linear=False,
|
| 261 |
-
use_checkpoint=True,
|
| 262 |
-
is_ctrl=False):
|
| 263 |
-
super().__init__()
|
| 264 |
-
if exists(context_dim) and not isinstance(context_dim, list):
|
| 265 |
-
context_dim = [context_dim]
|
| 266 |
-
self.in_channels = in_channels
|
| 267 |
-
inner_dim = n_heads * d_head
|
| 268 |
-
self.norm = torch.nn.GroupNorm(
|
| 269 |
-
num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
|
| 270 |
-
if not use_linear:
|
| 271 |
-
self.proj_in = nn.Conv2d(
|
| 272 |
-
in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
|
| 273 |
-
else:
|
| 274 |
-
self.proj_in = nn.Linear(in_channels, inner_dim)
|
| 275 |
-
|
| 276 |
-
self.transformer_blocks = nn.ModuleList([
|
| 277 |
-
BasicTransformerBlock(
|
| 278 |
-
inner_dim,
|
| 279 |
-
n_heads,
|
| 280 |
-
d_head,
|
| 281 |
-
dropout=dropout,
|
| 282 |
-
context_dim=context_dim[d],
|
| 283 |
-
disable_self_attn=disable_self_attn,
|
| 284 |
-
checkpoint=use_checkpoint,
|
| 285 |
-
local_type='space',
|
| 286 |
-
is_ctrl=is_ctrl) for d in range(depth)
|
| 287 |
-
])
|
| 288 |
-
if not use_linear:
|
| 289 |
-
self.proj_out = zero_module(
|
| 290 |
-
nn.Conv2d(
|
| 291 |
-
inner_dim, in_channels, kernel_size=1, stride=1,
|
| 292 |
-
padding=0))
|
| 293 |
-
else:
|
| 294 |
-
self.proj_out = zero_module(nn.Linear(in_channels, inner_dim))
|
| 295 |
-
self.use_linear = use_linear
|
| 296 |
-
|
| 297 |
-
def forward(self, x, context=None):
|
| 298 |
-
# note: if no context is given, cross-attention defaults to self-attention
|
| 299 |
-
if not isinstance(context, list):
|
| 300 |
-
context = [context]
|
| 301 |
-
_, _, h, w = x.shape
|
| 302 |
-
# print('x shape:', x.shape) # [64, 320, 90, 160]
|
| 303 |
-
x_in = x
|
| 304 |
-
x = self.norm(x)
|
| 305 |
-
if not self.use_linear:
|
| 306 |
-
x = self.proj_in(x)
|
| 307 |
-
x = rearrange(x, 'b c h w -> b (h w) c').contiguous()
|
| 308 |
-
if self.use_linear:
|
| 309 |
-
x = self.proj_in(x)
|
| 310 |
-
for i, block in enumerate(self.transformer_blocks):
|
| 311 |
-
x = block(x, context=context[i], h=h, w=w)
|
| 312 |
-
if self.use_linear:
|
| 313 |
-
x = self.proj_out(x)
|
| 314 |
-
x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w).contiguous()
|
| 315 |
-
if not self.use_linear:
|
| 316 |
-
x = self.proj_out(x)
|
| 317 |
-
return x + x_in
|
| 318 |
-
|
| 319 |
-
|
| 320 |
-
_ATTN_PRECISION = os.environ.get('ATTN_PRECISION', 'fp32')
|
| 321 |
-
|
| 322 |
-
|
| 323 |
-
class CrossAttention(nn.Module):
|
| 324 |
-
|
| 325 |
-
def __init__(self,
|
| 326 |
-
query_dim,
|
| 327 |
-
context_dim=None,
|
| 328 |
-
heads=8,
|
| 329 |
-
dim_head=64,
|
| 330 |
-
dropout=0.):
|
| 331 |
-
super().__init__()
|
| 332 |
-
inner_dim = dim_head * heads
|
| 333 |
-
context_dim = default(context_dim, query_dim)
|
| 334 |
-
|
| 335 |
-
self.scale = dim_head**-0.5
|
| 336 |
-
self.heads = heads
|
| 337 |
-
|
| 338 |
-
self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
|
| 339 |
-
self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
|
| 340 |
-
self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
|
| 341 |
-
|
| 342 |
-
self.to_out = nn.Sequential(
|
| 343 |
-
nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))
|
| 344 |
-
|
| 345 |
-
def forward(self, x, context=None, mask=None):
|
| 346 |
-
h = self.heads
|
| 347 |
-
|
| 348 |
-
q = self.to_q(x)
|
| 349 |
-
context = default(context, x)
|
| 350 |
-
k = self.to_k(context)
|
| 351 |
-
v = self.to_v(context)
|
| 352 |
-
|
| 353 |
-
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h),
|
| 354 |
-
(q, k, v))
|
| 355 |
-
|
| 356 |
-
# force cast to fp32 to avoid overflowing
|
| 357 |
-
if _ATTN_PRECISION == 'fp32':
|
| 358 |
-
with torch.autocast(enabled=False, device_type='cuda'):
|
| 359 |
-
q, k = q.float(), k.float()
|
| 360 |
-
sim = torch.einsum('b i d, b j d -> b i j', q, k) * self.scale
|
| 361 |
-
else:
|
| 362 |
-
sim = torch.einsum('b i d, b j d -> b i j', q, k) * self.scale
|
| 363 |
-
|
| 364 |
-
del q, k
|
| 365 |
-
|
| 366 |
-
if exists(mask):
|
| 367 |
-
mask = rearrange(mask, 'b ... -> b (...)')
|
| 368 |
-
max_neg_value = -torch.finfo(sim.dtype).max
|
| 369 |
-
mask = repeat(mask, 'b j -> (b h) () j', h=h)
|
| 370 |
-
sim.masked_fill_(~mask, max_neg_value)
|
| 371 |
-
|
| 372 |
-
# attention, what we cannot get enough of
|
| 373 |
-
sim = sim.softmax(dim=-1)
|
| 374 |
-
|
| 375 |
-
out = torch.einsum('b i j, b j d -> b i d', sim, v)
|
| 376 |
-
out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
|
| 377 |
-
return self.to_out(out)
|
| 378 |
-
|
| 379 |
-
|
| 380 |
-
|
| 381 |
-
|
| 382 |
-
class SpatialAttention(nn.Module):
|
| 383 |
-
def __init__(self):
|
| 384 |
-
super(SpatialAttention, self).__init__()
|
| 385 |
-
self.conv1 = nn.Conv2d(in_channels=2, out_channels=1, kernel_size=7, padding=7 // 2, bias=False)
|
| 386 |
-
self.sigmoid = nn.Sigmoid()
|
| 387 |
-
def forward(self, x):
|
| 388 |
-
|
| 389 |
-
max_out, _ = torch.max(x, dim=1, keepdim=True)
|
| 390 |
-
avg_out = torch.mean(x, dim=1, keepdim=True)
|
| 391 |
-
|
| 392 |
-
weight = torch.cat([max_out, avg_out], dim=1)
|
| 393 |
-
weight = self.conv1(weight)
|
| 394 |
-
|
| 395 |
-
out = self.sigmoid(weight) * x
|
| 396 |
-
return out
|
| 397 |
-
|
| 398 |
-
class TemporalLocalAttention(nn.Module): # b c t h w
|
| 399 |
-
def __init__(self, dim, kernel_size=7):
|
| 400 |
-
super(TemporalLocalAttention, self).__init__()
|
| 401 |
-
self.conv1 = nn.Linear(in_features=2, out_features=1, bias=False)
|
| 402 |
-
self.sigmoid = nn.Sigmoid()
|
| 403 |
-
|
| 404 |
-
def forward(self, x):
|
| 405 |
-
|
| 406 |
-
max_out, _ = torch.max(x, dim=-1, keepdim=True)
|
| 407 |
-
avg_out = torch.mean(x, dim=-1, keepdim=True)
|
| 408 |
-
|
| 409 |
-
weight = torch.cat([max_out, avg_out], dim=-1)
|
| 410 |
-
weight = self.conv1(weight)
|
| 411 |
-
|
| 412 |
-
out = self.sigmoid(weight) * x
|
| 413 |
-
return out
|
| 414 |
-
|
| 415 |
-
|
| 416 |
-
class BasicTransformerBlock(nn.Module):
|
| 417 |
-
|
| 418 |
-
def __init__(self,
|
| 419 |
-
dim,
|
| 420 |
-
n_heads,
|
| 421 |
-
d_head,
|
| 422 |
-
dropout=0.,
|
| 423 |
-
context_dim=None,
|
| 424 |
-
gated_ff=True,
|
| 425 |
-
checkpoint=True,
|
| 426 |
-
disable_self_attn=False,
|
| 427 |
-
local_type=None,
|
| 428 |
-
is_ctrl=False):
|
| 429 |
-
super().__init__()
|
| 430 |
-
self.local_type = local_type
|
| 431 |
-
self.is_ctrl = is_ctrl
|
| 432 |
-
attn_cls = MemoryEfficientCrossAttention
|
| 433 |
-
self.disable_self_attn = disable_self_attn
|
| 434 |
-
self.attn1 = attn_cls( # self-attn
|
| 435 |
-
query_dim=dim,
|
| 436 |
-
heads=n_heads,
|
| 437 |
-
dim_head=d_head,
|
| 438 |
-
dropout=dropout,
|
| 439 |
-
context_dim=context_dim if self.disable_self_attn else None)
|
| 440 |
-
self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
|
| 441 |
-
|
| 442 |
-
attn_cls2 = MemoryEfficientCrossAttention
|
| 443 |
-
|
| 444 |
-
self.attn2 = attn_cls2(
|
| 445 |
-
query_dim=dim,
|
| 446 |
-
context_dim=context_dim,
|
| 447 |
-
heads=n_heads,
|
| 448 |
-
dim_head=d_head,
|
| 449 |
-
dropout=dropout)
|
| 450 |
-
self.norm1 = nn.LayerNorm(dim)
|
| 451 |
-
self.norm2 = nn.LayerNorm(dim)
|
| 452 |
-
self.norm3 = nn.LayerNorm(dim)
|
| 453 |
-
self.checkpoint = checkpoint
|
| 454 |
-
|
| 455 |
-
if self.local_type == 'space' and self.is_ctrl:
|
| 456 |
-
self.local1 = SpatialAttention()
|
| 457 |
-
|
| 458 |
-
if self.local_type == 'temp' and self.is_ctrl:
|
| 459 |
-
self.local1 = TemporalLocalAttention(dim=dim)
|
| 460 |
-
self.local2 = TemporalLocalAttention(dim=dim)
|
| 461 |
-
|
| 462 |
-
def forward_(self, x, context=None):
|
| 463 |
-
return checkpoint(self._forward, (x, context), self.parameters(),
|
| 464 |
-
self.checkpoint)
|
| 465 |
-
|
| 466 |
-
def forward(self, x, context=None, h=None, w=None):
|
| 467 |
-
|
| 468 |
-
if self.local_type == 'space' and self.is_ctrl: # [b*t,(hw), c]
|
| 469 |
-
|
| 470 |
-
x_local = rearrange(x, 'b (h w) c -> b c h w', h=h)
|
| 471 |
-
x_local = self.local1(x_local)
|
| 472 |
-
x_local = rearrange(x_local, 'b c h w -> b (h w) c')
|
| 473 |
-
|
| 474 |
-
x = self.attn1(
|
| 475 |
-
self.norm1(x_local),
|
| 476 |
-
context=context if self.disable_self_attn else None) + x
|
| 477 |
-
|
| 478 |
-
x = self.attn2(self.norm2(x), context=context) + x # cross attention or self-attention
|
| 479 |
-
x = self.ff(self.norm3(x)) + x
|
| 480 |
-
|
| 481 |
-
if self.local_type == 'temp' and self.is_ctrl:
|
| 482 |
-
|
| 483 |
-
# x_local = rearrange(x, '(b h w) t c -> b c t h w', h=h, w=w)
|
| 484 |
-
x_local = self.local1(x)
|
| 485 |
-
|
| 486 |
-
x = self.attn1(
|
| 487 |
-
self.norm1(x_local),
|
| 488 |
-
context=context if self.disable_self_attn else None) + x
|
| 489 |
-
|
| 490 |
-
# x_local = rearrange(x, '(b h w) t c -> b c t h w', h=h, w=w)
|
| 491 |
-
x_local = self.local2(x)
|
| 492 |
-
|
| 493 |
-
x = self.attn2(self.norm2(x_local), context=context) + x
|
| 494 |
-
x = self.ff(self.norm3(x)) + x
|
| 495 |
-
|
| 496 |
-
# elif self.local_type == 'space' and self.is_ctrl:
|
| 497 |
-
# # print('*** use original attention ***')
|
| 498 |
-
# x = self.attn1(
|
| 499 |
-
# self.norm1(x),
|
| 500 |
-
# context=context if self.disable_self_attn else None) + x # self-attention
|
| 501 |
-
|
| 502 |
-
# x = self.attn2(self.norm2(x), context=context) + x # cross attention or self-attention
|
| 503 |
-
# x = self.ff(self.norm3(x)) + x
|
| 504 |
-
|
| 505 |
-
return x
|
| 506 |
-
|
| 507 |
-
|
| 508 |
-
# feedforward
|
| 509 |
-
class GEGLU(nn.Module):
|
| 510 |
-
|
| 511 |
-
def __init__(self, dim_in, dim_out):
|
| 512 |
-
super().__init__()
|
| 513 |
-
self.proj = nn.Linear(dim_in, dim_out * 2)
|
| 514 |
-
|
| 515 |
-
def forward(self, x):
|
| 516 |
-
x, gate = self.proj(x).chunk(2, dim=-1)
|
| 517 |
-
return x * F.gelu(gate)
|
| 518 |
-
|
| 519 |
-
|
| 520 |
-
def zero_module(module):
|
| 521 |
-
"""
|
| 522 |
-
Zero out the parameters of a module and return it.
|
| 523 |
-
"""
|
| 524 |
-
for p in module.parameters():
|
| 525 |
-
p.detach().zero_()
|
| 526 |
-
return module
|
| 527 |
-
|
| 528 |
-
|
| 529 |
-
class FeedForward(nn.Module):
|
| 530 |
-
|
| 531 |
-
def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
|
| 532 |
-
super().__init__()
|
| 533 |
-
inner_dim = int(dim * mult)
|
| 534 |
-
dim_out = default(dim_out, dim)
|
| 535 |
-
project_in = nn.Sequential(nn.Linear(
|
| 536 |
-
dim, inner_dim), nn.GELU()) if not glu else GEGLU(dim, inner_dim)
|
| 537 |
-
|
| 538 |
-
self.net = nn.Sequential(project_in, nn.Dropout(dropout),
|
| 539 |
-
nn.Linear(inner_dim, dim_out))
|
| 540 |
-
|
| 541 |
-
def forward(self, x):
|
| 542 |
-
return self.net(x)
|
| 543 |
-
|
| 544 |
-
|
| 545 |
-
class Upsample(nn.Module):
|
| 546 |
-
"""
|
| 547 |
-
An upsampling layer with an optional convolution.
|
| 548 |
-
:param channels: channels in the inputs and outputs.
|
| 549 |
-
:param use_conv: a bool determining if a convolution is applied.
|
| 550 |
-
:param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
|
| 551 |
-
upsampling occurs in the inner-two dimensions.
|
| 552 |
-
"""
|
| 553 |
-
|
| 554 |
-
def __init__(self,
|
| 555 |
-
channels,
|
| 556 |
-
use_conv,
|
| 557 |
-
dims=2,
|
| 558 |
-
out_channels=None,
|
| 559 |
-
padding=1):
|
| 560 |
-
super().__init__()
|
| 561 |
-
self.channels = channels
|
| 562 |
-
self.out_channels = out_channels or channels
|
| 563 |
-
self.use_conv = use_conv
|
| 564 |
-
self.dims = dims
|
| 565 |
-
if use_conv:
|
| 566 |
-
self.conv = nn.Conv2d(
|
| 567 |
-
self.channels, self.out_channels, 3, padding=padding)
|
| 568 |
-
|
| 569 |
-
def forward(self, x):
|
| 570 |
-
assert x.shape[1] == self.channels
|
| 571 |
-
if self.dims == 3:
|
| 572 |
-
x = F.interpolate(
|
| 573 |
-
x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2),
|
| 574 |
-
mode='nearest')
|
| 575 |
-
else:
|
| 576 |
-
x = F.interpolate(x, scale_factor=2, mode='nearest')
|
| 577 |
-
x = x[..., 1:-1, :]
|
| 578 |
-
if self.use_conv:
|
| 579 |
-
x = self.conv(x)
|
| 580 |
-
return x
|
| 581 |
-
|
| 582 |
-
|
| 583 |
-
class ResBlock(nn.Module):
|
| 584 |
-
"""
|
| 585 |
-
A residual block that can optionally change the number of channels.
|
| 586 |
-
:param channels: the number of input channels.
|
| 587 |
-
:param emb_channels: the number of timestep embedding channels.
|
| 588 |
-
:param dropout: the rate of dropout.
|
| 589 |
-
:param out_channels: if specified, the number of out channels.
|
| 590 |
-
:param use_conv: if True and out_channels is specified, use a spatial
|
| 591 |
-
convolution instead of a smaller 1x1 convolution to change the
|
| 592 |
-
channels in the skip connection.
|
| 593 |
-
:param dims: determines if the signal is 1D, 2D, or 3D.
|
| 594 |
-
:param use_checkpoint: if True, use gradient checkpointing on this module.
|
| 595 |
-
:param up: if True, use this block for upsampling.
|
| 596 |
-
:param down: if True, use this block for downsampling.
|
| 597 |
-
"""
|
| 598 |
-
|
| 599 |
-
def __init__(
|
| 600 |
-
self,
|
| 601 |
-
channels,
|
| 602 |
-
emb_channels,
|
| 603 |
-
dropout,
|
| 604 |
-
out_channels=None,
|
| 605 |
-
use_conv=False,
|
| 606 |
-
use_scale_shift_norm=False,
|
| 607 |
-
dims=2,
|
| 608 |
-
up=False,
|
| 609 |
-
down=False,
|
| 610 |
-
use_temporal_conv=True,
|
| 611 |
-
use_image_dataset=False,
|
| 612 |
-
):
|
| 613 |
-
super().__init__()
|
| 614 |
-
self.channels = channels
|
| 615 |
-
self.emb_channels = emb_channels
|
| 616 |
-
self.dropout = dropout
|
| 617 |
-
self.out_channels = out_channels or channels
|
| 618 |
-
self.use_conv = use_conv
|
| 619 |
-
self.use_scale_shift_norm = use_scale_shift_norm
|
| 620 |
-
self.use_temporal_conv = use_temporal_conv
|
| 621 |
-
|
| 622 |
-
self.in_layers = nn.Sequential(
|
| 623 |
-
nn.GroupNorm(32, channels),
|
| 624 |
-
nn.SiLU(),
|
| 625 |
-
nn.Conv2d(channels, self.out_channels, 3, padding=1),
|
| 626 |
-
)
|
| 627 |
-
|
| 628 |
-
self.updown = up or down
|
| 629 |
-
|
| 630 |
-
if up:
|
| 631 |
-
self.h_upd = Upsample(channels, False, dims)
|
| 632 |
-
self.x_upd = Upsample(channels, False, dims)
|
| 633 |
-
elif down:
|
| 634 |
-
self.h_upd = Downsample(channels, False, dims)
|
| 635 |
-
self.x_upd = Downsample(channels, False, dims)
|
| 636 |
-
else:
|
| 637 |
-
self.h_upd = self.x_upd = nn.Identity()
|
| 638 |
-
|
| 639 |
-
self.emb_layers = nn.Sequential(
|
| 640 |
-
nn.SiLU(),
|
| 641 |
-
nn.Linear(
|
| 642 |
-
emb_channels,
|
| 643 |
-
2 * self.out_channels
|
| 644 |
-
if use_scale_shift_norm else self.out_channels,
|
| 645 |
-
),
|
| 646 |
-
)
|
| 647 |
-
self.out_layers = nn.Sequential(
|
| 648 |
-
nn.GroupNorm(32, self.out_channels),
|
| 649 |
-
nn.SiLU(),
|
| 650 |
-
nn.Dropout(p=dropout),
|
| 651 |
-
zero_module(
|
| 652 |
-
nn.Conv2d(self.out_channels, self.out_channels, 3, padding=1)),
|
| 653 |
-
)
|
| 654 |
-
|
| 655 |
-
if self.out_channels == channels:
|
| 656 |
-
self.skip_connection = nn.Identity()
|
| 657 |
-
elif use_conv:
|
| 658 |
-
self.skip_connection = conv_nd(
|
| 659 |
-
dims, channels, self.out_channels, 3, padding=1)
|
| 660 |
-
else:
|
| 661 |
-
self.skip_connection = nn.Conv2d(channels, self.out_channels, 1)
|
| 662 |
-
|
| 663 |
-
if self.use_temporal_conv:
|
| 664 |
-
self.temopral_conv = TemporalConvBlock_v2(
|
| 665 |
-
self.out_channels,
|
| 666 |
-
self.out_channels,
|
| 667 |
-
dropout=0.1,
|
| 668 |
-
use_image_dataset=use_image_dataset)
|
| 669 |
-
|
| 670 |
-
def forward(self, x, emb, batch_size, variant_info=None):
|
| 671 |
-
"""
|
| 672 |
-
Apply the block to a Tensor, conditioned on a timestep embedding.
|
| 673 |
-
:param x: an [N x C x ...] Tensor of features.
|
| 674 |
-
:param emb: an [N x emb_channels] Tensor of timestep embeddings.
|
| 675 |
-
:return: an [N x C x ...] Tensor of outputs.
|
| 676 |
-
"""
|
| 677 |
-
return self._forward(x, emb, batch_size, variant_info)
|
| 678 |
-
|
| 679 |
-
def _forward(self, x, emb, batch_size, variant_info):
|
| 680 |
-
if self.updown:
|
| 681 |
-
in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
|
| 682 |
-
h = in_rest(x)
|
| 683 |
-
h = self.h_upd(h)
|
| 684 |
-
x = self.x_upd(x)
|
| 685 |
-
h = in_conv(h)
|
| 686 |
-
else:
|
| 687 |
-
h = self.in_layers(x)
|
| 688 |
-
emb_out = self.emb_layers(emb).type(h.dtype)
|
| 689 |
-
while len(emb_out.shape) < len(h.shape):
|
| 690 |
-
emb_out = emb_out[..., None]
|
| 691 |
-
if self.use_scale_shift_norm:
|
| 692 |
-
out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
|
| 693 |
-
scale, shift = th.chunk(emb_out, 2, dim=1)
|
| 694 |
-
h = out_norm(h) * (1 + scale) + shift
|
| 695 |
-
h = out_rest(h)
|
| 696 |
-
else:
|
| 697 |
-
h = h + emb_out
|
| 698 |
-
h = self.out_layers(h)
|
| 699 |
-
h = self.skip_connection(x) + h
|
| 700 |
-
|
| 701 |
-
if self.use_temporal_conv:
|
| 702 |
-
h = rearrange(h, '(b f) c h w -> b c f h w', b=batch_size)
|
| 703 |
-
h = self.temopral_conv(h, variant_info=variant_info)
|
| 704 |
-
h = rearrange(h, 'b c f h w -> (b f) c h w')
|
| 705 |
-
return h
|
| 706 |
-
|
| 707 |
-
|
| 708 |
-
class Downsample(nn.Module):
|
| 709 |
-
"""
|
| 710 |
-
A downsampling layer with an optional convolution.
|
| 711 |
-
:param channels: channels in the inputs and outputs.
|
| 712 |
-
:param use_conv: a bool determining if a convolution is applied.
|
| 713 |
-
:param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
|
| 714 |
-
downsampling occurs in the inner-two dimensions.
|
| 715 |
-
"""
|
| 716 |
-
|
| 717 |
-
def __init__(self,
|
| 718 |
-
channels,
|
| 719 |
-
use_conv,
|
| 720 |
-
dims=2,
|
| 721 |
-
out_channels=None,
|
| 722 |
-
padding=(2, 1)):
|
| 723 |
-
super().__init__()
|
| 724 |
-
self.channels = channels
|
| 725 |
-
self.out_channels = out_channels or channels
|
| 726 |
-
self.use_conv = use_conv
|
| 727 |
-
self.dims = dims
|
| 728 |
-
stride = 2 if dims != 3 else (1, 2, 2)
|
| 729 |
-
if use_conv:
|
| 730 |
-
self.op = nn.Conv2d(
|
| 731 |
-
self.channels,
|
| 732 |
-
self.out_channels,
|
| 733 |
-
3,
|
| 734 |
-
stride=stride,
|
| 735 |
-
padding=padding)
|
| 736 |
-
else:
|
| 737 |
-
assert self.channels == self.out_channels
|
| 738 |
-
self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)
|
| 739 |
-
|
| 740 |
-
def forward(self, x):
|
| 741 |
-
assert x.shape[1] == self.channels
|
| 742 |
-
return self.op(x)
|
| 743 |
-
|
| 744 |
-
|
| 745 |
-
class Resample(nn.Module):
|
| 746 |
-
|
| 747 |
-
def __init__(self, in_dim, out_dim, mode):
|
| 748 |
-
assert mode in ['none', 'upsample', 'downsample']
|
| 749 |
-
super(Resample, self).__init__()
|
| 750 |
-
self.in_dim = in_dim
|
| 751 |
-
self.out_dim = out_dim
|
| 752 |
-
self.mode = mode
|
| 753 |
-
|
| 754 |
-
def forward(self, x, reference=None):
|
| 755 |
-
if self.mode == 'upsample':
|
| 756 |
-
assert reference is not None
|
| 757 |
-
x = F.interpolate(x, size=reference.shape[-2:], mode='nearest')
|
| 758 |
-
elif self.mode == 'downsample':
|
| 759 |
-
x = F.adaptive_avg_pool2d(
|
| 760 |
-
x, output_size=tuple(u // 2 for u in x.shape[-2:]))
|
| 761 |
-
return x
|
| 762 |
-
|
| 763 |
-
|
| 764 |
-
class ResidualBlock(nn.Module):
|
| 765 |
-
|
| 766 |
-
def __init__(self,
|
| 767 |
-
in_dim,
|
| 768 |
-
embed_dim,
|
| 769 |
-
out_dim,
|
| 770 |
-
use_scale_shift_norm=True,
|
| 771 |
-
mode='none',
|
| 772 |
-
dropout=0.0):
|
| 773 |
-
super(ResidualBlock, self).__init__()
|
| 774 |
-
self.in_dim = in_dim
|
| 775 |
-
self.embed_dim = embed_dim
|
| 776 |
-
self.out_dim = out_dim
|
| 777 |
-
self.use_scale_shift_norm = use_scale_shift_norm
|
| 778 |
-
self.mode = mode
|
| 779 |
-
|
| 780 |
-
# layers
|
| 781 |
-
self.layer1 = nn.Sequential(
|
| 782 |
-
nn.GroupNorm(32, in_dim), nn.SiLU(),
|
| 783 |
-
nn.Conv2d(in_dim, out_dim, 3, padding=1))
|
| 784 |
-
self.resample = Resample(in_dim, in_dim, mode)
|
| 785 |
-
self.embedding = nn.Sequential(
|
| 786 |
-
nn.SiLU(),
|
| 787 |
-
nn.Linear(embed_dim,
|
| 788 |
-
out_dim * 2 if use_scale_shift_norm else out_dim))
|
| 789 |
-
self.layer2 = nn.Sequential(
|
| 790 |
-
nn.GroupNorm(32, out_dim), nn.SiLU(), nn.Dropout(dropout),
|
| 791 |
-
nn.Conv2d(out_dim, out_dim, 3, padding=1))
|
| 792 |
-
self.shortcut = nn.Identity() if in_dim == out_dim else nn.Conv2d(
|
| 793 |
-
in_dim, out_dim, 1)
|
| 794 |
-
|
| 795 |
-
# zero out the last layer params
|
| 796 |
-
nn.init.zeros_(self.layer2[-1].weight)
|
| 797 |
-
|
| 798 |
-
def forward(self, x, e, reference=None):
|
| 799 |
-
identity = self.resample(x, reference)
|
| 800 |
-
x = self.layer1[-1](self.resample(self.layer1[:-1](x), reference))
|
| 801 |
-
e = self.embedding(e).unsqueeze(-1).unsqueeze(-1).type(x.dtype)
|
| 802 |
-
if self.use_scale_shift_norm:
|
| 803 |
-
scale, shift = e.chunk(2, dim=1)
|
| 804 |
-
x = self.layer2[0](x) * (1 + scale) + shift
|
| 805 |
-
x = self.layer2[1:](x)
|
| 806 |
-
else:
|
| 807 |
-
x = x + e
|
| 808 |
-
x = self.layer2(x)
|
| 809 |
-
x = x + self.shortcut(identity)
|
| 810 |
-
return x
|
| 811 |
-
|
| 812 |
-
|
| 813 |
-
class AttentionBlock(nn.Module):
|
| 814 |
-
|
| 815 |
-
def __init__(self, dim, context_dim=None, num_heads=None, head_dim=None):
|
| 816 |
-
# consider head_dim first, then num_heads
|
| 817 |
-
num_heads = dim // head_dim if head_dim else num_heads
|
| 818 |
-
head_dim = dim // num_heads
|
| 819 |
-
assert num_heads * head_dim == dim
|
| 820 |
-
super(AttentionBlock, self).__init__()
|
| 821 |
-
self.dim = dim
|
| 822 |
-
self.context_dim = context_dim
|
| 823 |
-
self.num_heads = num_heads
|
| 824 |
-
self.head_dim = head_dim
|
| 825 |
-
self.scale = math.pow(head_dim, -0.25)
|
| 826 |
-
|
| 827 |
-
# layers
|
| 828 |
-
self.norm = nn.GroupNorm(32, dim)
|
| 829 |
-
self.to_qkv = nn.Conv2d(dim, dim * 3, 1)
|
| 830 |
-
if context_dim is not None:
|
| 831 |
-
self.context_kv = nn.Linear(context_dim, dim * 2)
|
| 832 |
-
self.proj = nn.Conv2d(dim, dim, 1)
|
| 833 |
-
|
| 834 |
-
# zero out the last layer params
|
| 835 |
-
nn.init.zeros_(self.proj.weight)
|
| 836 |
-
|
| 837 |
-
def forward(self, x, context=None):
|
| 838 |
-
r"""x: [B, C, H, W].
|
| 839 |
-
context: [B, L, C] or None.
|
| 840 |
-
"""
|
| 841 |
-
identity = x
|
| 842 |
-
b, c, h, w, n, d = *x.size(), self.num_heads, self.head_dim
|
| 843 |
-
|
| 844 |
-
# compute query, key, value
|
| 845 |
-
x = self.norm(x)
|
| 846 |
-
q, k, v = self.to_qkv(x).view(b, n * 3, d, h * w).chunk(3, dim=1)
|
| 847 |
-
if context is not None:
|
| 848 |
-
ck, cv = self.context_kv(context).reshape(b, -1, n * 2,
|
| 849 |
-
d).permute(0, 2, 3,
|
| 850 |
-
1).chunk(
|
| 851 |
-
2, dim=1)
|
| 852 |
-
k = torch.cat([ck, k], dim=-1)
|
| 853 |
-
v = torch.cat([cv, v], dim=-1)
|
| 854 |
-
|
| 855 |
-
# compute attention
|
| 856 |
-
attn = torch.matmul(q.transpose(-1, -2) * self.scale, k * self.scale)
|
| 857 |
-
attn = F.softmax(attn, dim=-1)
|
| 858 |
-
|
| 859 |
-
# gather context
|
| 860 |
-
x = torch.matmul(v, attn.transpose(-1, -2))
|
| 861 |
-
x = x.reshape(b, c, h, w)
|
| 862 |
-
|
| 863 |
-
# output
|
| 864 |
-
x = self.proj(x)
|
| 865 |
-
return x + identity
|
| 866 |
-
|
| 867 |
-
|
| 868 |
-
class TemporalAttentionBlock(nn.Module):
|
| 869 |
-
|
| 870 |
-
def __init__(self,
|
| 871 |
-
dim,
|
| 872 |
-
heads=4,
|
| 873 |
-
dim_head=32,
|
| 874 |
-
rotary_emb=None,
|
| 875 |
-
use_image_dataset=False,
|
| 876 |
-
use_sim_mask=False):
|
| 877 |
-
super().__init__()
|
| 878 |
-
# consider num_heads first, as pos_bias needs fixed num_heads
|
| 879 |
-
dim_head = dim // heads
|
| 880 |
-
assert heads * dim_head == dim
|
| 881 |
-
self.use_image_dataset = use_image_dataset
|
| 882 |
-
self.use_sim_mask = use_sim_mask
|
| 883 |
-
|
| 884 |
-
self.scale = dim_head**-0.5
|
| 885 |
-
self.heads = heads
|
| 886 |
-
hidden_dim = dim_head * heads
|
| 887 |
-
|
| 888 |
-
self.norm = nn.GroupNorm(32, dim)
|
| 889 |
-
self.rotary_emb = rotary_emb
|
| 890 |
-
self.to_qkv = nn.Linear(dim, hidden_dim * 3)
|
| 891 |
-
self.to_out = nn.Linear(hidden_dim, dim)
|
| 892 |
-
|
| 893 |
-
def forward(self,
|
| 894 |
-
x,
|
| 895 |
-
pos_bias=None,
|
| 896 |
-
focus_present_mask=None,
|
| 897 |
-
video_mask=None):
|
| 898 |
-
|
| 899 |
-
identity = x
|
| 900 |
-
n, height, device = x.shape[2], x.shape[-2], x.device
|
| 901 |
-
|
| 902 |
-
x = self.norm(x)
|
| 903 |
-
x = rearrange(x, 'b c f h w -> b (h w) f c')
|
| 904 |
-
|
| 905 |
-
qkv = self.to_qkv(x).chunk(3, dim=-1)
|
| 906 |
-
|
| 907 |
-
if exists(focus_present_mask) and focus_present_mask.all():
|
| 908 |
-
# if all batch samples are focusing on present
|
| 909 |
-
# it would be equivalent to passing that token's values (v=qkv[-1]) through to the output
|
| 910 |
-
values = qkv[-1]
|
| 911 |
-
out = self.to_out(values)
|
| 912 |
-
out = rearrange(out, 'b (h w) f c -> b c f h w', h=height)
|
| 913 |
-
|
| 914 |
-
return out + identity
|
| 915 |
-
|
| 916 |
-
# split out heads
|
| 917 |
-
q = rearrange(qkv[0], '... n (h d) -> ... h n d', h=self.heads)
|
| 918 |
-
k = rearrange(qkv[1], '... n (h d) -> ... h n d', h=self.heads)
|
| 919 |
-
v = rearrange(qkv[2], '... n (h d) -> ... h n d', h=self.heads)
|
| 920 |
-
|
| 921 |
-
# scale
|
| 922 |
-
|
| 923 |
-
q = q * self.scale
|
| 924 |
-
|
| 925 |
-
# rotate positions into queries and keys for time attention
|
| 926 |
-
if exists(self.rotary_emb):
|
| 927 |
-
q = self.rotary_emb.rotate_queries_or_keys(q)
|
| 928 |
-
k = self.rotary_emb.rotate_queries_or_keys(k)
|
| 929 |
-
|
| 930 |
-
# similarity
|
| 931 |
-
# shape [b (hw) h n n], n=f
|
| 932 |
-
sim = torch.einsum('... h i d, ... h j d -> ... h i j', q, k)
|
| 933 |
-
|
| 934 |
-
# relative positional bias
|
| 935 |
-
|
| 936 |
-
if exists(pos_bias):
|
| 937 |
-
sim = sim + pos_bias
|
| 938 |
-
|
| 939 |
-
if (focus_present_mask is None and video_mask is not None):
|
| 940 |
-
# video_mask: [B, n]
|
| 941 |
-
mask = video_mask[:, None, :] * video_mask[:, :, None]
|
| 942 |
-
mask = mask.unsqueeze(1).unsqueeze(1)
|
| 943 |
-
sim = sim.masked_fill(~mask, -torch.finfo(sim.dtype).max)
|
| 944 |
-
elif exists(focus_present_mask) and not (~focus_present_mask).all():
|
| 945 |
-
attend_all_mask = torch.ones((n, n),
|
| 946 |
-
device=device,
|
| 947 |
-
dtype=torch.bool)
|
| 948 |
-
attend_self_mask = torch.eye(n, device=device, dtype=torch.bool)
|
| 949 |
-
|
| 950 |
-
mask = torch.where(
|
| 951 |
-
rearrange(focus_present_mask, 'b -> b 1 1 1 1'),
|
| 952 |
-
rearrange(attend_self_mask, 'i j -> 1 1 1 i j'),
|
| 953 |
-
rearrange(attend_all_mask, 'i j -> 1 1 1 i j'),
|
| 954 |
-
)
|
| 955 |
-
|
| 956 |
-
sim = sim.masked_fill(~mask, -torch.finfo(sim.dtype).max)
|
| 957 |
-
|
| 958 |
-
if self.use_sim_mask:
|
| 959 |
-
sim_mask = torch.tril(
|
| 960 |
-
torch.ones((n, n), device=device, dtype=torch.bool),
|
| 961 |
-
diagonal=0)
|
| 962 |
-
sim = sim.masked_fill(~sim_mask, -torch.finfo(sim.dtype).max)
|
| 963 |
-
|
| 964 |
-
# numerical stability
|
| 965 |
-
sim = sim - sim.amax(dim=-1, keepdim=True).detach()
|
| 966 |
-
attn = sim.softmax(dim=-1)
|
| 967 |
-
|
| 968 |
-
# aggregate values
|
| 969 |
-
|
| 970 |
-
out = torch.einsum('... h i j, ... h j d -> ... h i d', attn, v)
|
| 971 |
-
out = rearrange(out, '... h n d -> ... n (h d)')
|
| 972 |
-
out = self.to_out(out)
|
| 973 |
-
|
| 974 |
-
out = rearrange(out, 'b (h w) f c -> b c f h w', h=height)
|
| 975 |
-
|
| 976 |
-
if self.use_image_dataset:
|
| 977 |
-
out = identity + 0 * out
|
| 978 |
-
else:
|
| 979 |
-
out = identity + out
|
| 980 |
-
return out
|
| 981 |
-
|
| 982 |
-
|
| 983 |
-
class TemporalTransformer(nn.Module):
|
| 984 |
-
"""
|
| 985 |
-
Transformer block for image-like data.
|
| 986 |
-
First, project the input (aka embedding)
|
| 987 |
-
and reshape to b, t, d.
|
| 988 |
-
Then apply standard transformer action.
|
| 989 |
-
Finally, reshape to image
|
| 990 |
-
"""
|
| 991 |
-
|
| 992 |
-
def __init__(self,
|
| 993 |
-
in_channels,
|
| 994 |
-
n_heads,
|
| 995 |
-
d_head,
|
| 996 |
-
depth=1,
|
| 997 |
-
dropout=0.,
|
| 998 |
-
context_dim=None,
|
| 999 |
-
disable_self_attn=False,
|
| 1000 |
-
use_linear=False,
|
| 1001 |
-
use_checkpoint=True,
|
| 1002 |
-
only_self_att=True,
|
| 1003 |
-
multiply_zero=False,
|
| 1004 |
-
is_ctrl=False):
|
| 1005 |
-
super().__init__()
|
| 1006 |
-
self.multiply_zero = multiply_zero
|
| 1007 |
-
self.only_self_att = only_self_att
|
| 1008 |
-
self.use_adaptor = False
|
| 1009 |
-
if self.only_self_att:
|
| 1010 |
-
context_dim = None
|
| 1011 |
-
if not isinstance(context_dim, list):
|
| 1012 |
-
context_dim = [context_dim]
|
| 1013 |
-
self.in_channels = in_channels
|
| 1014 |
-
inner_dim = n_heads * d_head
|
| 1015 |
-
self.norm = torch.nn.GroupNorm(
|
| 1016 |
-
num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
|
| 1017 |
-
if not use_linear:
|
| 1018 |
-
self.proj_in = nn.Conv1d(
|
| 1019 |
-
in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
|
| 1020 |
-
else:
|
| 1021 |
-
self.proj_in = nn.Linear(in_channels, inner_dim)
|
| 1022 |
-
if self.use_adaptor:
|
| 1023 |
-
self.adaptor_in = nn.Linear(frames, frames)
|
| 1024 |
-
|
| 1025 |
-
self.transformer_blocks = nn.ModuleList([
|
| 1026 |
-
BasicTransformerBlock(
|
| 1027 |
-
inner_dim,
|
| 1028 |
-
n_heads,
|
| 1029 |
-
d_head,
|
| 1030 |
-
dropout=dropout,
|
| 1031 |
-
context_dim=context_dim[d],
|
| 1032 |
-
checkpoint=use_checkpoint,
|
| 1033 |
-
local_type='temp',
|
| 1034 |
-
is_ctrl=is_ctrl) for d in range(depth)
|
| 1035 |
-
])
|
| 1036 |
-
if not use_linear:
|
| 1037 |
-
self.proj_out = zero_module(
|
| 1038 |
-
nn.Conv1d(
|
| 1039 |
-
inner_dim, in_channels, kernel_size=1, stride=1,
|
| 1040 |
-
padding=0))
|
| 1041 |
-
else:
|
| 1042 |
-
self.proj_out = zero_module(nn.Linear(inner_dim, in_channels))
|
| 1043 |
-
if self.use_adaptor:
|
| 1044 |
-
self.adaptor_out = nn.Linear(frames, frames)
|
| 1045 |
-
self.use_linear = use_linear
|
| 1046 |
-
|
| 1047 |
-
def forward(self, x, context=None):
|
| 1048 |
-
# note: if no context is given, cross-attention defaults to self-attention
|
| 1049 |
-
if self.only_self_att:
|
| 1050 |
-
context = None
|
| 1051 |
-
if not isinstance(context, list):
|
| 1052 |
-
context = [context]
|
| 1053 |
-
b, _, _, h, w = x.shape
|
| 1054 |
-
x_in = x
|
| 1055 |
-
x = self.norm(x)
|
| 1056 |
-
|
| 1057 |
-
if not self.use_linear:
|
| 1058 |
-
x = rearrange(x, 'b c f h w -> (b h w) c f').contiguous()
|
| 1059 |
-
x = self.proj_in(x)
|
| 1060 |
-
if self.use_linear:
|
| 1061 |
-
x = rearrange(
|
| 1062 |
-
x, 'b c f h w -> (b h w) f c').contiguous()
|
| 1063 |
-
x = self.proj_in(x)
|
| 1064 |
-
x = rearrange(
|
| 1065 |
-
x, 'bhw f c -> bhw c f').contiguous()
|
| 1066 |
-
|
| 1067 |
-
# print('x shape:', x.shape) # [28800, 512, 32]
|
| 1068 |
-
if self.only_self_att: # no cross-attention
|
| 1069 |
-
x = rearrange(x, 'bhw c f -> bhw f c').contiguous()
|
| 1070 |
-
for i, block in enumerate(self.transformer_blocks):
|
| 1071 |
-
x = block(x, h=h, w=w)
|
| 1072 |
-
# print('x shape:', x.shape) # [43200, 32, 512]
|
| 1073 |
-
x = rearrange(x, '(b hw) f c -> b hw f c', b=b).contiguous()
|
| 1074 |
-
else:
|
| 1075 |
-
x = rearrange(x, '(b hw) c f -> b hw f c', b=b).contiguous()
|
| 1076 |
-
for i, block in enumerate(self.transformer_blocks):
|
| 1077 |
-
context[i] = rearrange(
|
| 1078 |
-
context[i], '(b f) l con -> b f l con',
|
| 1079 |
-
f=self.frames).contiguous()
|
| 1080 |
-
# calculate each batch one by one (since number in shape could not greater then 65,535 for some package)
|
| 1081 |
-
for j in range(b):
|
| 1082 |
-
context_i_j = repeat(
|
| 1083 |
-
context[i][j],
|
| 1084 |
-
'f l con -> (f r) l con',
|
| 1085 |
-
r=(h * w) // self.frames,
|
| 1086 |
-
f=self.frames).contiguous()
|
| 1087 |
-
x[j] = block(x[j], context=context_i_j)
|
| 1088 |
-
|
| 1089 |
-
if self.use_linear:
|
| 1090 |
-
x = rearrange(x, 'b hw f c -> (b hw) f c').contiguous()
|
| 1091 |
-
x = self.proj_out(x)
|
| 1092 |
-
x = rearrange(
|
| 1093 |
-
x, '(b h w) f c -> b c f h w', b=b, h=h, w=w).contiguous()
|
| 1094 |
-
if not self.use_linear:
|
| 1095 |
-
# print('x shape:', x.shape) # [2, 21600, 32, 512]
|
| 1096 |
-
x = rearrange(x, 'b hw f c -> (b hw) c f').contiguous()
|
| 1097 |
-
x = self.proj_out(x)
|
| 1098 |
-
x = rearrange(
|
| 1099 |
-
x, '(b h w) c f -> b c f h w', b=b, h=h, w=w).contiguous()
|
| 1100 |
-
|
| 1101 |
-
if self.multiply_zero:
|
| 1102 |
-
x = 0.0 * x + x_in
|
| 1103 |
-
else:
|
| 1104 |
-
x = x + x_in
|
| 1105 |
-
return x
|
| 1106 |
-
|
| 1107 |
-
|
| 1108 |
-
class TemporalAttentionMultiBlock(nn.Module):
|
| 1109 |
-
|
| 1110 |
-
def __init__(
|
| 1111 |
-
self,
|
| 1112 |
-
dim,
|
| 1113 |
-
heads=4,
|
| 1114 |
-
dim_head=32,
|
| 1115 |
-
rotary_emb=None,
|
| 1116 |
-
use_image_dataset=False,
|
| 1117 |
-
use_sim_mask=False,
|
| 1118 |
-
temporal_attn_times=1,
|
| 1119 |
-
):
|
| 1120 |
-
super().__init__()
|
| 1121 |
-
self.att_layers = nn.ModuleList([
|
| 1122 |
-
TemporalAttentionBlock(dim, heads, dim_head, rotary_emb,
|
| 1123 |
-
use_image_dataset, use_sim_mask)
|
| 1124 |
-
for _ in range(temporal_attn_times)
|
| 1125 |
-
])
|
| 1126 |
-
|
| 1127 |
-
def forward(self,
|
| 1128 |
-
x,
|
| 1129 |
-
pos_bias=None,
|
| 1130 |
-
focus_present_mask=None,
|
| 1131 |
-
video_mask=None):
|
| 1132 |
-
for layer in self.att_layers:
|
| 1133 |
-
x = layer(x, pos_bias, focus_present_mask, video_mask)
|
| 1134 |
-
return x
|
| 1135 |
-
|
| 1136 |
-
|
| 1137 |
-
class InitTemporalConvBlock(nn.Module):
|
| 1138 |
-
|
| 1139 |
-
def __init__(self,
|
| 1140 |
-
in_dim,
|
| 1141 |
-
out_dim=None,
|
| 1142 |
-
dropout=0.0,
|
| 1143 |
-
use_image_dataset=False):
|
| 1144 |
-
super(InitTemporalConvBlock, self).__init__()
|
| 1145 |
-
if out_dim is None:
|
| 1146 |
-
out_dim = in_dim
|
| 1147 |
-
self.in_dim = in_dim
|
| 1148 |
-
self.out_dim = out_dim
|
| 1149 |
-
self.use_image_dataset = use_image_dataset
|
| 1150 |
-
|
| 1151 |
-
# conv layers
|
| 1152 |
-
self.conv = nn.Sequential(
|
| 1153 |
-
nn.GroupNorm(32, out_dim), nn.SiLU(), nn.Dropout(dropout),
|
| 1154 |
-
nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding=(1, 0, 0)))
|
| 1155 |
-
|
| 1156 |
-
# zero out the last layer params,so the conv block is identity
|
| 1157 |
-
nn.init.zeros_(self.conv[-1].weight)
|
| 1158 |
-
nn.init.zeros_(self.conv[-1].bias)
|
| 1159 |
-
|
| 1160 |
-
def forward(self, x):
|
| 1161 |
-
identity = x
|
| 1162 |
-
x = self.conv(x)
|
| 1163 |
-
if self.use_image_dataset:
|
| 1164 |
-
x = identity + 0 * x
|
| 1165 |
-
else:
|
| 1166 |
-
x = identity + x
|
| 1167 |
-
return x
|
| 1168 |
-
|
| 1169 |
-
|
| 1170 |
-
class TemporalConvBlock(nn.Module):
|
| 1171 |
-
|
| 1172 |
-
def __init__(self,
|
| 1173 |
-
in_dim,
|
| 1174 |
-
out_dim=None,
|
| 1175 |
-
dropout=0.0,
|
| 1176 |
-
use_image_dataset=False):
|
| 1177 |
-
super(TemporalConvBlock, self).__init__()
|
| 1178 |
-
if out_dim is None:
|
| 1179 |
-
out_dim = in_dim
|
| 1180 |
-
self.in_dim = in_dim
|
| 1181 |
-
self.out_dim = out_dim
|
| 1182 |
-
self.use_image_dataset = use_image_dataset
|
| 1183 |
-
|
| 1184 |
-
# conv layers
|
| 1185 |
-
self.conv1 = nn.Sequential(
|
| 1186 |
-
nn.GroupNorm(32, in_dim), nn.SiLU(),
|
| 1187 |
-
nn.Conv3d(in_dim, out_dim, (3, 1, 1), padding=(1, 0, 0)))
|
| 1188 |
-
self.conv2 = nn.Sequential(
|
| 1189 |
-
nn.GroupNorm(32, out_dim), nn.SiLU(), nn.Dropout(dropout),
|
| 1190 |
-
nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding=(1, 0, 0)))
|
| 1191 |
-
|
| 1192 |
-
# zero out the last layer params,so the conv block is identity
|
| 1193 |
-
nn.init.zeros_(self.conv2[-1].weight)
|
| 1194 |
-
nn.init.zeros_(self.conv2[-1].bias)
|
| 1195 |
-
|
| 1196 |
-
def forward(self, x):
|
| 1197 |
-
identity = x
|
| 1198 |
-
x = self.conv1(x)
|
| 1199 |
-
x = self.conv2(x)
|
| 1200 |
-
if self.use_image_dataset:
|
| 1201 |
-
x = identity + 0 * x
|
| 1202 |
-
else:
|
| 1203 |
-
x = identity + x
|
| 1204 |
-
return x
|
| 1205 |
-
|
| 1206 |
-
|
| 1207 |
-
class TemporalConvBlock_v2(nn.Module):
|
| 1208 |
-
|
| 1209 |
-
def __init__(self,
|
| 1210 |
-
in_dim,
|
| 1211 |
-
out_dim=None,
|
| 1212 |
-
dropout=0.0,
|
| 1213 |
-
use_image_dataset=False):
|
| 1214 |
-
super(TemporalConvBlock_v2, self).__init__()
|
| 1215 |
-
if out_dim is None:
|
| 1216 |
-
out_dim = in_dim
|
| 1217 |
-
self.in_dim = in_dim
|
| 1218 |
-
self.out_dim = out_dim
|
| 1219 |
-
self.use_image_dataset = use_image_dataset
|
| 1220 |
-
|
| 1221 |
-
# conv layers
|
| 1222 |
-
self.conv1 = nn.Sequential(
|
| 1223 |
-
nn.GroupNorm(32, in_dim), nn.SiLU(),
|
| 1224 |
-
nn.Conv3d(in_dim, out_dim, (3, 1, 1), padding=(1, 0, 0)))
|
| 1225 |
-
self.conv2 = nn.Sequential(
|
| 1226 |
-
nn.GroupNorm(32, out_dim), nn.SiLU(), nn.Dropout(dropout),
|
| 1227 |
-
nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding=(1, 0, 0)))
|
| 1228 |
-
self.conv3 = nn.Sequential(
|
| 1229 |
-
nn.GroupNorm(32, out_dim), nn.SiLU(), nn.Dropout(dropout),
|
| 1230 |
-
nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding=(1, 0, 0)))
|
| 1231 |
-
self.conv4 = nn.Sequential(
|
| 1232 |
-
nn.GroupNorm(32, out_dim), nn.SiLU(), nn.Dropout(dropout),
|
| 1233 |
-
nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding=(1, 0, 0)))
|
| 1234 |
-
|
| 1235 |
-
# zero out the last layer params,so the conv block is identity
|
| 1236 |
-
nn.init.zeros_(self.conv4[-1].weight)
|
| 1237 |
-
nn.init.zeros_(self.conv4[-1].bias)
|
| 1238 |
-
|
| 1239 |
-
def forward(self, x, variant_info=None):
|
| 1240 |
-
if variant_info is not None and variant_info.get('type') == 'variant2':
|
| 1241 |
-
# print(x.shape) # torch.Size([1, 320, 32, 90, 160])
|
| 1242 |
-
_, _, f, _, _ = x.shape
|
| 1243 |
-
assert f % 4 == 0, "f must be divisible by 4"
|
| 1244 |
-
x_short = rearrange(x, "b c (n s) h w -> (n b) c s h w", n=4)
|
| 1245 |
-
x_short = self.conv1(x_short)
|
| 1246 |
-
x_short = self.conv2(x_short)
|
| 1247 |
-
x_short = self.conv3(x_short)
|
| 1248 |
-
x_short = self.conv4(x_short)
|
| 1249 |
-
x_short = rearrange(x_short, "(n b) c s h w -> b c (n s) h w", n=4)
|
| 1250 |
-
|
| 1251 |
-
identity = x
|
| 1252 |
-
x = self.conv1(x)
|
| 1253 |
-
x = self.conv2(x)
|
| 1254 |
-
x = self.conv3(x)
|
| 1255 |
-
x = self.conv4(x)
|
| 1256 |
-
|
| 1257 |
-
x = x * (1-variant_info['alpha']) + x_short * variant_info['alpha']
|
| 1258 |
-
|
| 1259 |
-
|
| 1260 |
-
elif variant_info is not None and variant_info.get('type') == 'variant1':
|
| 1261 |
-
identity = x
|
| 1262 |
-
x_long, x_short = x.chunk(2, dim=0)
|
| 1263 |
-
|
| 1264 |
-
x_short = rearrange(x_short, "b c (n s) h w -> (n b) c s h w", n=4)
|
| 1265 |
-
x_short = self.conv1(x_short)
|
| 1266 |
-
x_short = self.conv2(x_short)
|
| 1267 |
-
x_short = self.conv3(x_short)
|
| 1268 |
-
x_short = self.conv4(x_short)
|
| 1269 |
-
x_short = rearrange(x_short, "(n b) c s h w -> b c (n s) h w", n=4)
|
| 1270 |
-
|
| 1271 |
-
x_long = self.conv1(x_long)
|
| 1272 |
-
x_long = self.conv2(x_long)
|
| 1273 |
-
x_long = self.conv3(x_long)
|
| 1274 |
-
x_long = self.conv4(x_long)
|
| 1275 |
-
|
| 1276 |
-
x = torch.cat([x_long, x_short], dim=0)
|
| 1277 |
-
|
| 1278 |
-
|
| 1279 |
-
elif variant_info is None:
|
| 1280 |
-
identity = x
|
| 1281 |
-
x = self.conv1(x)
|
| 1282 |
-
x = self.conv2(x)
|
| 1283 |
-
x = self.conv3(x)
|
| 1284 |
-
x = self.conv4(x)
|
| 1285 |
-
|
| 1286 |
-
|
| 1287 |
-
if self.use_image_dataset:
|
| 1288 |
-
x = identity + 0.0 * x
|
| 1289 |
-
else:
|
| 1290 |
-
x = identity + x
|
| 1291 |
-
return x
|
| 1292 |
-
|
| 1293 |
-
|
| 1294 |
-
class Vid2VidSDUNet(nn.Module):
|
| 1295 |
-
|
| 1296 |
-
def __init__(self,
|
| 1297 |
-
in_dim=4,
|
| 1298 |
-
dim=320,
|
| 1299 |
-
y_dim=1024,
|
| 1300 |
-
context_dim=1024,
|
| 1301 |
-
out_dim=4,
|
| 1302 |
-
dim_mult=[1, 2, 4, 4],
|
| 1303 |
-
num_heads=8,
|
| 1304 |
-
head_dim=64,
|
| 1305 |
-
num_res_blocks=2,
|
| 1306 |
-
attn_scales=[1 / 1, 1 / 2, 1 / 4],
|
| 1307 |
-
use_scale_shift_norm=True,
|
| 1308 |
-
dropout=0.1,
|
| 1309 |
-
temporal_attn_times=1,
|
| 1310 |
-
temporal_attention=True,
|
| 1311 |
-
use_checkpoint=True,
|
| 1312 |
-
use_image_dataset=False,
|
| 1313 |
-
use_fps_condition=False,
|
| 1314 |
-
use_sim_mask=False,
|
| 1315 |
-
training=False,
|
| 1316 |
-
inpainting=True):
|
| 1317 |
-
embed_dim = dim * 4
|
| 1318 |
-
num_heads = num_heads if num_heads else dim // 32
|
| 1319 |
-
super(Vid2VidSDUNet, self).__init__()
|
| 1320 |
-
self.in_dim = in_dim
|
| 1321 |
-
self.dim = dim
|
| 1322 |
-
self.y_dim = y_dim
|
| 1323 |
-
self.context_dim = context_dim
|
| 1324 |
-
self.embed_dim = embed_dim
|
| 1325 |
-
self.out_dim = out_dim
|
| 1326 |
-
self.dim_mult = dim_mult
|
| 1327 |
-
# for temporal attention
|
| 1328 |
-
self.num_heads = num_heads
|
| 1329 |
-
# for spatial attention
|
| 1330 |
-
self.head_dim = head_dim
|
| 1331 |
-
self.num_res_blocks = num_res_blocks
|
| 1332 |
-
self.attn_scales = attn_scales
|
| 1333 |
-
self.use_scale_shift_norm = use_scale_shift_norm
|
| 1334 |
-
self.temporal_attn_times = temporal_attn_times
|
| 1335 |
-
self.temporal_attention = temporal_attention
|
| 1336 |
-
self.use_checkpoint = use_checkpoint
|
| 1337 |
-
self.use_image_dataset = use_image_dataset
|
| 1338 |
-
self.use_fps_condition = use_fps_condition
|
| 1339 |
-
self.use_sim_mask = use_sim_mask
|
| 1340 |
-
self.training = training
|
| 1341 |
-
self.inpainting = inpainting
|
| 1342 |
-
|
| 1343 |
-
use_linear_in_temporal = False
|
| 1344 |
-
transformer_depth = 1
|
| 1345 |
-
disabled_sa = False
|
| 1346 |
-
# params
|
| 1347 |
-
enc_dims = [dim * u for u in [1] + dim_mult]
|
| 1348 |
-
dec_dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
|
| 1349 |
-
shortcut_dims = []
|
| 1350 |
-
scale = 1.0
|
| 1351 |
-
|
| 1352 |
-
# embeddings
|
| 1353 |
-
self.time_embed = nn.Sequential(
|
| 1354 |
-
nn.Linear(dim, embed_dim), nn.SiLU(),
|
| 1355 |
-
nn.Linear(embed_dim, embed_dim))
|
| 1356 |
-
|
| 1357 |
-
if self.use_fps_condition:
|
| 1358 |
-
self.fps_embedding = nn.Sequential(
|
| 1359 |
-
nn.Linear(dim, embed_dim), nn.SiLU(),
|
| 1360 |
-
nn.Linear(embed_dim, embed_dim))
|
| 1361 |
-
nn.init.zeros_(self.fps_embedding[-1].weight)
|
| 1362 |
-
nn.init.zeros_(self.fps_embedding[-1].bias)
|
| 1363 |
-
|
| 1364 |
-
# encoder
|
| 1365 |
-
self.input_blocks = nn.ModuleList()
|
| 1366 |
-
init_block = nn.ModuleList([nn.Conv2d(self.in_dim, dim, 3, padding=1)])
|
| 1367 |
-
# need an initial temporal attention?
|
| 1368 |
-
if temporal_attention:
|
| 1369 |
-
if USE_TEMPORAL_TRANSFORMER:
|
| 1370 |
-
init_block.append(
|
| 1371 |
-
TemporalTransformer(
|
| 1372 |
-
dim,
|
| 1373 |
-
num_heads,
|
| 1374 |
-
head_dim,
|
| 1375 |
-
depth=transformer_depth,
|
| 1376 |
-
context_dim=context_dim,
|
| 1377 |
-
disable_self_attn=disabled_sa,
|
| 1378 |
-
use_linear=use_linear_in_temporal,
|
| 1379 |
-
multiply_zero=use_image_dataset,
|
| 1380 |
-
is_ctrl=True
|
| 1381 |
-
))
|
| 1382 |
-
else:
|
| 1383 |
-
init_block.append(
|
| 1384 |
-
TemporalAttentionMultiBlock(
|
| 1385 |
-
dim,
|
| 1386 |
-
num_heads,
|
| 1387 |
-
head_dim,
|
| 1388 |
-
rotary_emb=self.rotary_emb,
|
| 1389 |
-
temporal_attn_times=temporal_attn_times,
|
| 1390 |
-
use_image_dataset=use_image_dataset))
|
| 1391 |
-
self.input_blocks.append(init_block)
|
| 1392 |
-
shortcut_dims.append(dim)
|
| 1393 |
-
for i, (in_dim,
|
| 1394 |
-
out_dim) in enumerate(zip(enc_dims[:-1], enc_dims[1:])):
|
| 1395 |
-
for j in range(num_res_blocks):
|
| 1396 |
-
block = nn.ModuleList([
|
| 1397 |
-
ResBlock(
|
| 1398 |
-
in_dim,
|
| 1399 |
-
embed_dim,
|
| 1400 |
-
dropout,
|
| 1401 |
-
out_channels=out_dim,
|
| 1402 |
-
use_scale_shift_norm=False,
|
| 1403 |
-
use_image_dataset=use_image_dataset,
|
| 1404 |
-
)
|
| 1405 |
-
])
|
| 1406 |
-
if scale in attn_scales:
|
| 1407 |
-
block.append(
|
| 1408 |
-
SpatialTransformer(
|
| 1409 |
-
out_dim,
|
| 1410 |
-
out_dim // head_dim,
|
| 1411 |
-
head_dim,
|
| 1412 |
-
depth=1,
|
| 1413 |
-
context_dim=self.context_dim,
|
| 1414 |
-
disable_self_attn=False,
|
| 1415 |
-
use_linear=True,
|
| 1416 |
-
is_ctrl=True
|
| 1417 |
-
))
|
| 1418 |
-
if self.temporal_attention:
|
| 1419 |
-
if USE_TEMPORAL_TRANSFORMER:
|
| 1420 |
-
block.append(
|
| 1421 |
-
TemporalTransformer(
|
| 1422 |
-
out_dim,
|
| 1423 |
-
out_dim // head_dim,
|
| 1424 |
-
head_dim,
|
| 1425 |
-
depth=transformer_depth,
|
| 1426 |
-
context_dim=context_dim,
|
| 1427 |
-
disable_self_attn=disabled_sa,
|
| 1428 |
-
use_linear=use_linear_in_temporal,
|
| 1429 |
-
multiply_zero=use_image_dataset,
|
| 1430 |
-
is_ctrl=True
|
| 1431 |
-
))
|
| 1432 |
-
else:
|
| 1433 |
-
block.append(
|
| 1434 |
-
TemporalAttentionMultiBlock(
|
| 1435 |
-
out_dim,
|
| 1436 |
-
num_heads,
|
| 1437 |
-
head_dim,
|
| 1438 |
-
rotary_emb=self.rotary_emb,
|
| 1439 |
-
use_image_dataset=use_image_dataset,
|
| 1440 |
-
use_sim_mask=use_sim_mask,
|
| 1441 |
-
temporal_attn_times=temporal_attn_times))
|
| 1442 |
-
in_dim = out_dim
|
| 1443 |
-
self.input_blocks.append(block)
|
| 1444 |
-
shortcut_dims.append(out_dim)
|
| 1445 |
-
|
| 1446 |
-
# downsample
|
| 1447 |
-
if i != len(dim_mult) - 1 and j == num_res_blocks - 1:
|
| 1448 |
-
downsample = Downsample(
|
| 1449 |
-
out_dim, True, dims=2, out_channels=out_dim)
|
| 1450 |
-
shortcut_dims.append(out_dim)
|
| 1451 |
-
scale /= 2.0
|
| 1452 |
-
self.input_blocks.append(downsample)
|
| 1453 |
-
|
| 1454 |
-
self.middle_block = nn.ModuleList([
|
| 1455 |
-
ResBlock(
|
| 1456 |
-
out_dim,
|
| 1457 |
-
embed_dim,
|
| 1458 |
-
dropout,
|
| 1459 |
-
use_scale_shift_norm=False,
|
| 1460 |
-
use_image_dataset=use_image_dataset,
|
| 1461 |
-
),
|
| 1462 |
-
SpatialTransformer(
|
| 1463 |
-
out_dim,
|
| 1464 |
-
out_dim // head_dim,
|
| 1465 |
-
head_dim,
|
| 1466 |
-
depth=1,
|
| 1467 |
-
context_dim=self.context_dim,
|
| 1468 |
-
disable_self_attn=False,
|
| 1469 |
-
use_linear=True,
|
| 1470 |
-
is_ctrl=True
|
| 1471 |
-
)
|
| 1472 |
-
])
|
| 1473 |
-
|
| 1474 |
-
if self.temporal_attention:
|
| 1475 |
-
if USE_TEMPORAL_TRANSFORMER:
|
| 1476 |
-
self.middle_block.append(
|
| 1477 |
-
TemporalTransformer(
|
| 1478 |
-
out_dim,
|
| 1479 |
-
out_dim // head_dim,
|
| 1480 |
-
head_dim,
|
| 1481 |
-
depth=transformer_depth,
|
| 1482 |
-
context_dim=context_dim,
|
| 1483 |
-
disable_self_attn=disabled_sa,
|
| 1484 |
-
use_linear=use_linear_in_temporal,
|
| 1485 |
-
multiply_zero=use_image_dataset,
|
| 1486 |
-
is_ctrl=True
|
| 1487 |
-
|
| 1488 |
-
))
|
| 1489 |
-
else:
|
| 1490 |
-
self.middle_block.append(
|
| 1491 |
-
TemporalAttentionMultiBlock(
|
| 1492 |
-
out_dim,
|
| 1493 |
-
num_heads,
|
| 1494 |
-
head_dim,
|
| 1495 |
-
rotary_emb=self.rotary_emb,
|
| 1496 |
-
use_image_dataset=use_image_dataset,
|
| 1497 |
-
use_sim_mask=use_sim_mask,
|
| 1498 |
-
temporal_attn_times=temporal_attn_times))
|
| 1499 |
-
|
| 1500 |
-
self.middle_block.append(
|
| 1501 |
-
ResBlock(out_dim, embed_dim, dropout, use_scale_shift_norm=False))
|
| 1502 |
-
|
| 1503 |
-
# decoder
|
| 1504 |
-
self.output_blocks = nn.ModuleList()
|
| 1505 |
-
for i, (in_dim,
|
| 1506 |
-
out_dim) in enumerate(zip(dec_dims[:-1], dec_dims[1:])):
|
| 1507 |
-
for j in range(num_res_blocks + 1):
|
| 1508 |
-
block = nn.ModuleList([
|
| 1509 |
-
ResBlock(
|
| 1510 |
-
in_dim + shortcut_dims.pop(),
|
| 1511 |
-
embed_dim,
|
| 1512 |
-
dropout,
|
| 1513 |
-
out_dim,
|
| 1514 |
-
use_scale_shift_norm=False,
|
| 1515 |
-
use_image_dataset=use_image_dataset,
|
| 1516 |
-
)
|
| 1517 |
-
])
|
| 1518 |
-
if scale in attn_scales:
|
| 1519 |
-
block.append(
|
| 1520 |
-
SpatialTransformer(
|
| 1521 |
-
out_dim,
|
| 1522 |
-
out_dim // head_dim,
|
| 1523 |
-
head_dim,
|
| 1524 |
-
depth=1,
|
| 1525 |
-
context_dim=1024,
|
| 1526 |
-
disable_self_attn=False,
|
| 1527 |
-
use_linear=True,
|
| 1528 |
-
is_ctrl=True))
|
| 1529 |
-
if self.temporal_attention:
|
| 1530 |
-
if USE_TEMPORAL_TRANSFORMER:
|
| 1531 |
-
block.append(
|
| 1532 |
-
TemporalTransformer(
|
| 1533 |
-
out_dim,
|
| 1534 |
-
out_dim // head_dim,
|
| 1535 |
-
head_dim,
|
| 1536 |
-
depth=transformer_depth,
|
| 1537 |
-
context_dim=context_dim,
|
| 1538 |
-
disable_self_attn=disabled_sa,
|
| 1539 |
-
use_linear=use_linear_in_temporal,
|
| 1540 |
-
multiply_zero=use_image_dataset,
|
| 1541 |
-
is_ctrl=True))
|
| 1542 |
-
else:
|
| 1543 |
-
block.append(
|
| 1544 |
-
TemporalAttentionMultiBlock(
|
| 1545 |
-
out_dim,
|
| 1546 |
-
num_heads,
|
| 1547 |
-
head_dim,
|
| 1548 |
-
rotary_emb=self.rotary_emb,
|
| 1549 |
-
use_image_dataset=use_image_dataset,
|
| 1550 |
-
use_sim_mask=use_sim_mask,
|
| 1551 |
-
temporal_attn_times=temporal_attn_times))
|
| 1552 |
-
in_dim = out_dim
|
| 1553 |
-
|
| 1554 |
-
# upsample
|
| 1555 |
-
if i != len(dim_mult) - 1 and j == num_res_blocks:
|
| 1556 |
-
upsample = Upsample(
|
| 1557 |
-
out_dim, True, dims=2.0, out_channels=out_dim)
|
| 1558 |
-
scale *= 2.0
|
| 1559 |
-
block.append(upsample)
|
| 1560 |
-
self.output_blocks.append(block)
|
| 1561 |
-
|
| 1562 |
-
# head
|
| 1563 |
-
self.out = nn.Sequential(
|
| 1564 |
-
nn.GroupNorm(32, out_dim), nn.SiLU(),
|
| 1565 |
-
nn.Conv2d(out_dim, self.out_dim, 3, padding=1))
|
| 1566 |
-
|
| 1567 |
-
# zero out the last layer params
|
| 1568 |
-
nn.init.zeros_(self.out[-1].weight)
|
| 1569 |
-
|
| 1570 |
-
def forward(self,
|
| 1571 |
-
x,
|
| 1572 |
-
t,
|
| 1573 |
-
y,
|
| 1574 |
-
x_lr=None,
|
| 1575 |
-
fps=None,
|
| 1576 |
-
video_mask=None,
|
| 1577 |
-
focus_present_mask=None,
|
| 1578 |
-
prob_focus_present=0.,
|
| 1579 |
-
mask_last_frame_num=0):
|
| 1580 |
-
|
| 1581 |
-
batch, c, f, h, w = x.shape
|
| 1582 |
-
device = x.device
|
| 1583 |
-
self.batch = batch
|
| 1584 |
-
|
| 1585 |
-
# image and video joint training, if mask_last_frame_num is set, prob_focus_present will be ignored
|
| 1586 |
-
if mask_last_frame_num > 0:
|
| 1587 |
-
focus_present_mask = None
|
| 1588 |
-
video_mask[-mask_last_frame_num:] = False
|
| 1589 |
-
else:
|
| 1590 |
-
focus_present_mask = default(
|
| 1591 |
-
focus_present_mask, lambda: prob_mask_like(
|
| 1592 |
-
(batch, ), prob_focus_present, device=device))
|
| 1593 |
-
|
| 1594 |
-
if self.temporal_attention and not USE_TEMPORAL_TRANSFORMER:
|
| 1595 |
-
time_rel_pos_bias = self.time_rel_pos_bias(
|
| 1596 |
-
x.shape[2], device=x.device)
|
| 1597 |
-
else:
|
| 1598 |
-
time_rel_pos_bias = None
|
| 1599 |
-
|
| 1600 |
-
# embeddings
|
| 1601 |
-
e = self.time_embed(sinusoidal_embedding(t, self.dim))
|
| 1602 |
-
context = y
|
| 1603 |
-
|
| 1604 |
-
# repeat f times for spatial e and context
|
| 1605 |
-
e = e.repeat_interleave(repeats=f, dim=0)
|
| 1606 |
-
context = context.repeat_interleave(repeats=f, dim=0)
|
| 1607 |
-
|
| 1608 |
-
# always in shape (b f) c h w, except for temporal layer
|
| 1609 |
-
x = rearrange(x, 'b c f h w -> (b f) c h w')
|
| 1610 |
-
# encoder
|
| 1611 |
-
xs = []
|
| 1612 |
-
for ind, block in enumerate(self.input_blocks):
|
| 1613 |
-
x = self._forward_single(block, x, e, context, time_rel_pos_bias,
|
| 1614 |
-
focus_present_mask, video_mask)
|
| 1615 |
-
xs.append(x)
|
| 1616 |
-
|
| 1617 |
-
# middle
|
| 1618 |
-
for block in self.middle_block:
|
| 1619 |
-
x = self._forward_single(block, x, e, context, time_rel_pos_bias,
|
| 1620 |
-
focus_present_mask, video_mask)
|
| 1621 |
-
|
| 1622 |
-
# decoder
|
| 1623 |
-
for block in self.output_blocks:
|
| 1624 |
-
x = torch.cat([x, xs.pop()], dim=1)
|
| 1625 |
-
x = self._forward_single(
|
| 1626 |
-
block,
|
| 1627 |
-
x,
|
| 1628 |
-
e,
|
| 1629 |
-
context,
|
| 1630 |
-
time_rel_pos_bias,
|
| 1631 |
-
focus_present_mask,
|
| 1632 |
-
video_mask,
|
| 1633 |
-
reference=xs[-1] if len(xs) > 0 else None)
|
| 1634 |
-
|
| 1635 |
-
# head
|
| 1636 |
-
x = self.out(x)
|
| 1637 |
-
|
| 1638 |
-
# reshape back to (b c f h w)
|
| 1639 |
-
x = rearrange(x, '(b f) c h w -> b c f h w', b=batch)
|
| 1640 |
-
return x
|
| 1641 |
-
|
| 1642 |
-
def _forward_single(self,
|
| 1643 |
-
module,
|
| 1644 |
-
x,
|
| 1645 |
-
e,
|
| 1646 |
-
context,
|
| 1647 |
-
time_rel_pos_bias,
|
| 1648 |
-
focus_present_mask,
|
| 1649 |
-
video_mask,
|
| 1650 |
-
reference=None):
|
| 1651 |
-
if isinstance(module, ResidualBlock):
|
| 1652 |
-
module = checkpoint_wrapper(
|
| 1653 |
-
module) if self.use_checkpoint else module
|
| 1654 |
-
x = x.contiguous()
|
| 1655 |
-
x = module(x, e, reference)
|
| 1656 |
-
elif isinstance(module, ResBlock):
|
| 1657 |
-
module = checkpoint_wrapper(
|
| 1658 |
-
module) if self.use_checkpoint else module
|
| 1659 |
-
x = x.contiguous()
|
| 1660 |
-
x = module(x, e, self.batch)
|
| 1661 |
-
elif isinstance(module, SpatialTransformer):
|
| 1662 |
-
module = checkpoint_wrapper(
|
| 1663 |
-
module) if self.use_checkpoint else module
|
| 1664 |
-
x = module(x, context)
|
| 1665 |
-
elif isinstance(module, TemporalTransformer):
|
| 1666 |
-
module = checkpoint_wrapper(
|
| 1667 |
-
module) if self.use_checkpoint else module
|
| 1668 |
-
x = rearrange(x, '(b f) c h w -> b c f h w', b=self.batch)
|
| 1669 |
-
x = module(x, context)
|
| 1670 |
-
x = rearrange(x, 'b c f h w -> (b f) c h w')
|
| 1671 |
-
elif isinstance(module, CrossAttention):
|
| 1672 |
-
module = checkpoint_wrapper(
|
| 1673 |
-
module) if self.use_checkpoint else module
|
| 1674 |
-
x = module(x, context)
|
| 1675 |
-
elif isinstance(module, MemoryEfficientCrossAttention):
|
| 1676 |
-
module = checkpoint_wrapper(
|
| 1677 |
-
module) if self.use_checkpoint else module
|
| 1678 |
-
x = module(x, context)
|
| 1679 |
-
elif isinstance(module, BasicTransformerBlock):
|
| 1680 |
-
module = checkpoint_wrapper(
|
| 1681 |
-
module) if self.use_checkpoint else module
|
| 1682 |
-
x = module(x, context)
|
| 1683 |
-
elif isinstance(module, FeedForward):
|
| 1684 |
-
x = module(x, context)
|
| 1685 |
-
elif isinstance(module, Upsample):
|
| 1686 |
-
x = module(x)
|
| 1687 |
-
elif isinstance(module, Downsample):
|
| 1688 |
-
x = module(x)
|
| 1689 |
-
elif isinstance(module, Resample):
|
| 1690 |
-
x = module(x, reference)
|
| 1691 |
-
elif isinstance(module, TemporalAttentionBlock):
|
| 1692 |
-
module = checkpoint_wrapper(
|
| 1693 |
-
module) if self.use_checkpoint else module
|
| 1694 |
-
x = rearrange(x, '(b f) c h w -> b c f h w', b=self.batch)
|
| 1695 |
-
x = module(x, time_rel_pos_bias, focus_present_mask, video_mask)
|
| 1696 |
-
x = rearrange(x, 'b c f h w -> (b f) c h w')
|
| 1697 |
-
elif isinstance(module, TemporalAttentionMultiBlock):
|
| 1698 |
-
module = checkpoint_wrapper(
|
| 1699 |
-
module) if self.use_checkpoint else module
|
| 1700 |
-
x = rearrange(x, '(b f) c h w -> b c f h w', b=self.batch)
|
| 1701 |
-
x = module(x, time_rel_pos_bias, focus_present_mask, video_mask)
|
| 1702 |
-
x = rearrange(x, 'b c f h w -> (b f) c h w')
|
| 1703 |
-
elif isinstance(module, InitTemporalConvBlock):
|
| 1704 |
-
module = checkpoint_wrapper(
|
| 1705 |
-
module) if self.use_checkpoint else module
|
| 1706 |
-
x = rearrange(x, '(b f) c h w -> b c f h w', b=self.batch)
|
| 1707 |
-
x = module(x)
|
| 1708 |
-
x = rearrange(x, 'b c f h w -> (b f) c h w')
|
| 1709 |
-
elif isinstance(module, TemporalConvBlock):
|
| 1710 |
-
module = checkpoint_wrapper(
|
| 1711 |
-
module) if self.use_checkpoint else module
|
| 1712 |
-
x = rearrange(x, '(b f) c h w -> b c f h w', b=self.batch)
|
| 1713 |
-
x = module(x)
|
| 1714 |
-
x = rearrange(x, 'b c f h w -> (b f) c h w')
|
| 1715 |
-
elif isinstance(module, nn.ModuleList):
|
| 1716 |
-
for block in module:
|
| 1717 |
-
x = self._forward_single(block, x, e, context,
|
| 1718 |
-
time_rel_pos_bias, focus_present_mask,
|
| 1719 |
-
video_mask, reference)
|
| 1720 |
-
else:
|
| 1721 |
-
x = module(x)
|
| 1722 |
-
return x
|
| 1723 |
-
|
| 1724 |
-
|
| 1725 |
-
class ControlledV2VUNet(Vid2VidSDUNet):
|
| 1726 |
-
def __init__(self):
|
| 1727 |
-
super(ControlledV2VUNet, self).__init__()
|
| 1728 |
-
self.VideoControlNet = VideoControlNet()
|
| 1729 |
-
|
| 1730 |
-
def forward(self,
|
| 1731 |
-
x,
|
| 1732 |
-
t,
|
| 1733 |
-
y,
|
| 1734 |
-
hint=None,
|
| 1735 |
-
variant_info=None,
|
| 1736 |
-
hint_chunk=None,
|
| 1737 |
-
t_hint=None,
|
| 1738 |
-
s_cond=None,
|
| 1739 |
-
mask_cond=None,
|
| 1740 |
-
x_lr=None,
|
| 1741 |
-
fps=None,
|
| 1742 |
-
mask=None,
|
| 1743 |
-
video_mask=None,
|
| 1744 |
-
focus_present_mask=None,
|
| 1745 |
-
prob_focus_present=0.,
|
| 1746 |
-
mask_last_frame_num=0,
|
| 1747 |
-
):
|
| 1748 |
-
|
| 1749 |
-
batch, _, f, _, _= x.shape
|
| 1750 |
-
device = x.device
|
| 1751 |
-
self.batch = batch
|
| 1752 |
-
|
| 1753 |
-
# Process text (new added for t5 encoder)
|
| 1754 |
-
# y = self.VideoControlNet.y_embedder(y, self.training).squeeze(1) # [1, 1, 120, 4096] -> [B, 1, 120, 1024].squeeze(1) -> [B, 120, 1024]
|
| 1755 |
-
|
| 1756 |
-
if hint_chunk is not None:
|
| 1757 |
-
hint = hint_chunk
|
| 1758 |
-
|
| 1759 |
-
control = self.VideoControlNet(x, t, y, hint=hint, t_hint=t_hint, \
|
| 1760 |
-
mask_cond=mask_cond, s_cond=s_cond, \
|
| 1761 |
-
variant_info=variant_info)
|
| 1762 |
-
|
| 1763 |
-
# image and video joint training, if mask_last_frame_num is set, prob_focus_present will be ignored
|
| 1764 |
-
if mask_last_frame_num > 0:
|
| 1765 |
-
focus_present_mask = None
|
| 1766 |
-
video_mask[-mask_last_frame_num:] = False
|
| 1767 |
-
else:
|
| 1768 |
-
focus_present_mask = default(
|
| 1769 |
-
focus_present_mask, lambda: prob_mask_like(
|
| 1770 |
-
(batch, ), prob_focus_present, device=device))
|
| 1771 |
-
|
| 1772 |
-
if self.temporal_attention and not USE_TEMPORAL_TRANSFORMER:
|
| 1773 |
-
time_rel_pos_bias = self.time_rel_pos_bias(
|
| 1774 |
-
x.shape[2], device=x.device)
|
| 1775 |
-
else:
|
| 1776 |
-
time_rel_pos_bias = None
|
| 1777 |
-
|
| 1778 |
-
e = self.time_embed(sinusoidal_embedding(t, self.dim))
|
| 1779 |
-
e = e.repeat_interleave(repeats=f, dim=0)
|
| 1780 |
-
|
| 1781 |
-
# context = y
|
| 1782 |
-
context = y.repeat_interleave(repeats=f, dim=0)
|
| 1783 |
-
|
| 1784 |
-
# always in shape (b f) c h w, except for temporal layer
|
| 1785 |
-
x = rearrange(x, 'b c f h w -> (b f) c h w')
|
| 1786 |
-
# encoder
|
| 1787 |
-
xs = []
|
| 1788 |
-
for block in self.input_blocks:
|
| 1789 |
-
x = self._forward_single(block, x, e, context, time_rel_pos_bias,
|
| 1790 |
-
focus_present_mask, video_mask, variant_info=variant_info)
|
| 1791 |
-
xs.append(x)
|
| 1792 |
-
# middle
|
| 1793 |
-
for block in self.middle_block:
|
| 1794 |
-
x = self._forward_single(block, x, e, context, time_rel_pos_bias,
|
| 1795 |
-
focus_present_mask, video_mask, variant_info=variant_info)
|
| 1796 |
-
|
| 1797 |
-
if control is not None:
|
| 1798 |
-
x = control.pop() + x
|
| 1799 |
-
|
| 1800 |
-
# decoder
|
| 1801 |
-
for block in self.output_blocks:
|
| 1802 |
-
if control is None:
|
| 1803 |
-
x = torch.cat([x, xs.pop()], dim=1)
|
| 1804 |
-
else:
|
| 1805 |
-
x = torch.cat([x, xs.pop() + control.pop()], dim=1)
|
| 1806 |
-
x = self._forward_single(
|
| 1807 |
-
block,
|
| 1808 |
-
x,
|
| 1809 |
-
e,
|
| 1810 |
-
context,
|
| 1811 |
-
time_rel_pos_bias,
|
| 1812 |
-
focus_present_mask,
|
| 1813 |
-
video_mask,
|
| 1814 |
-
reference=xs[-1] if len(xs) > 0 else None,
|
| 1815 |
-
variant_info=variant_info)
|
| 1816 |
-
|
| 1817 |
-
# head
|
| 1818 |
-
x = self.out(x)
|
| 1819 |
-
|
| 1820 |
-
# reshape back to (b c f h w)
|
| 1821 |
-
x = rearrange(x, '(b f) c h w -> b c f h w', b=batch)
|
| 1822 |
-
return x
|
| 1823 |
-
|
| 1824 |
-
def _forward_single(self,
|
| 1825 |
-
module,
|
| 1826 |
-
x,
|
| 1827 |
-
e,
|
| 1828 |
-
context,
|
| 1829 |
-
time_rel_pos_bias,
|
| 1830 |
-
focus_present_mask,
|
| 1831 |
-
video_mask,
|
| 1832 |
-
reference=None,
|
| 1833 |
-
variant_info=None):
|
| 1834 |
-
variant_info = None # For Debug
|
| 1835 |
-
if isinstance(module, ResidualBlock):
|
| 1836 |
-
module = checkpoint_wrapper(
|
| 1837 |
-
module) if self.use_checkpoint else module
|
| 1838 |
-
x = x.contiguous()
|
| 1839 |
-
x = module(x, e, reference)
|
| 1840 |
-
elif isinstance(module, ResBlock):
|
| 1841 |
-
module = checkpoint_wrapper(
|
| 1842 |
-
module) if self.use_checkpoint else module
|
| 1843 |
-
x = x.contiguous()
|
| 1844 |
-
x = module(x, e, self.batch, variant_info)
|
| 1845 |
-
elif isinstance(module, SpatialTransformer):
|
| 1846 |
-
module = checkpoint_wrapper(
|
| 1847 |
-
module) if self.use_checkpoint else module
|
| 1848 |
-
x = module(x, context)
|
| 1849 |
-
elif isinstance(module, TemporalTransformer):
|
| 1850 |
-
module = checkpoint_wrapper(
|
| 1851 |
-
module) if self.use_checkpoint else module
|
| 1852 |
-
x = rearrange(x, '(b f) c h w -> b c f h w', b=self.batch)
|
| 1853 |
-
x = module(x, context)
|
| 1854 |
-
x = rearrange(x, 'b c f h w -> (b f) c h w')
|
| 1855 |
-
elif isinstance(module, CrossAttention):
|
| 1856 |
-
module = checkpoint_wrapper(
|
| 1857 |
-
module) if self.use_checkpoint else module
|
| 1858 |
-
x = module(x, context)
|
| 1859 |
-
elif isinstance(module, MemoryEfficientCrossAttention):
|
| 1860 |
-
module = checkpoint_wrapper(
|
| 1861 |
-
module) if self.use_checkpoint else module
|
| 1862 |
-
x = module(x, context)
|
| 1863 |
-
elif isinstance(module, BasicTransformerBlock):
|
| 1864 |
-
module = checkpoint_wrapper(
|
| 1865 |
-
module) if self.use_checkpoint else module
|
| 1866 |
-
x = module(x, context)
|
| 1867 |
-
elif isinstance(module, FeedForward):
|
| 1868 |
-
x = module(x, context)
|
| 1869 |
-
elif isinstance(module, Upsample):
|
| 1870 |
-
x = module(x)
|
| 1871 |
-
elif isinstance(module, Downsample):
|
| 1872 |
-
x = module(x)
|
| 1873 |
-
elif isinstance(module, Resample):
|
| 1874 |
-
x = module(x, reference)
|
| 1875 |
-
elif isinstance(module, TemporalAttentionBlock):
|
| 1876 |
-
module = checkpoint_wrapper(
|
| 1877 |
-
module) if self.use_checkpoint else module
|
| 1878 |
-
x = rearrange(x, '(b f) c h w -> b c f h w', b=self.batch)
|
| 1879 |
-
x = module(x, time_rel_pos_bias, focus_present_mask, video_mask)
|
| 1880 |
-
x = rearrange(x, 'b c f h w -> (b f) c h w')
|
| 1881 |
-
elif isinstance(module, TemporalAttentionMultiBlock):
|
| 1882 |
-
module = checkpoint_wrapper(
|
| 1883 |
-
module) if self.use_checkpoint else module
|
| 1884 |
-
x = rearrange(x, '(b f) c h w -> b c f h w', b=self.batch)
|
| 1885 |
-
x = module(x, time_rel_pos_bias, focus_present_mask, video_mask)
|
| 1886 |
-
x = rearrange(x, 'b c f h w -> (b f) c h w')
|
| 1887 |
-
elif isinstance(module, InitTemporalConvBlock):
|
| 1888 |
-
module = checkpoint_wrapper(
|
| 1889 |
-
module) if self.use_checkpoint else module
|
| 1890 |
-
x = rearrange(x, '(b f) c h w -> b c f h w', b=self.batch)
|
| 1891 |
-
x = module(x)
|
| 1892 |
-
x = rearrange(x, 'b c f h w -> (b f) c h w')
|
| 1893 |
-
elif isinstance(module, TemporalConvBlock):
|
| 1894 |
-
module = checkpoint_wrapper(
|
| 1895 |
-
module) if self.use_checkpoint else module
|
| 1896 |
-
x = rearrange(x, '(b f) c h w -> b c f h w', b=self.batch)
|
| 1897 |
-
x = module(x)
|
| 1898 |
-
x = rearrange(x, 'b c f h w -> (b f) c h w')
|
| 1899 |
-
elif isinstance(module, nn.ModuleList):
|
| 1900 |
-
for block in module:
|
| 1901 |
-
x = self._forward_single(block, x, e, context,
|
| 1902 |
-
time_rel_pos_bias, focus_present_mask,
|
| 1903 |
-
video_mask, reference, variant_info)
|
| 1904 |
-
else:
|
| 1905 |
-
x = module(x)
|
| 1906 |
-
return x
|
| 1907 |
-
|
| 1908 |
-
|
| 1909 |
-
class VideoControlNet(nn.Module):
|
| 1910 |
-
|
| 1911 |
-
def __init__(self,
|
| 1912 |
-
in_dim=4,
|
| 1913 |
-
dim=320,
|
| 1914 |
-
y_dim=1024,
|
| 1915 |
-
context_dim=1024,
|
| 1916 |
-
out_dim=4,
|
| 1917 |
-
dim_mult=[1, 2, 4, 4],
|
| 1918 |
-
num_heads=8,
|
| 1919 |
-
head_dim=64,
|
| 1920 |
-
num_res_blocks=2,
|
| 1921 |
-
attn_scales=[1 / 1, 1 / 2, 1 / 4],
|
| 1922 |
-
use_scale_shift_norm=True,
|
| 1923 |
-
dropout=0.1,
|
| 1924 |
-
temporal_attn_times=1,
|
| 1925 |
-
temporal_attention=True,
|
| 1926 |
-
use_checkpoint=True,
|
| 1927 |
-
use_image_dataset=False,
|
| 1928 |
-
use_fps_condition=False,
|
| 1929 |
-
use_sim_mask=False,
|
| 1930 |
-
training=False,
|
| 1931 |
-
inpainting=True):
|
| 1932 |
-
embed_dim = dim * 4
|
| 1933 |
-
num_heads = num_heads if num_heads else dim // 32
|
| 1934 |
-
super(VideoControlNet, self).__init__()
|
| 1935 |
-
self.in_dim = in_dim
|
| 1936 |
-
self.dim = dim
|
| 1937 |
-
self.y_dim = y_dim
|
| 1938 |
-
self.context_dim = context_dim
|
| 1939 |
-
self.embed_dim = embed_dim
|
| 1940 |
-
self.out_dim = out_dim
|
| 1941 |
-
self.dim_mult = dim_mult
|
| 1942 |
-
# for temporal attention
|
| 1943 |
-
self.num_heads = num_heads
|
| 1944 |
-
# for spatial attention
|
| 1945 |
-
self.head_dim = head_dim
|
| 1946 |
-
self.num_res_blocks = num_res_blocks
|
| 1947 |
-
self.attn_scales = attn_scales
|
| 1948 |
-
self.use_scale_shift_norm = use_scale_shift_norm
|
| 1949 |
-
self.temporal_attn_times = temporal_attn_times
|
| 1950 |
-
self.temporal_attention = temporal_attention
|
| 1951 |
-
self.use_checkpoint = use_checkpoint
|
| 1952 |
-
self.use_image_dataset = use_image_dataset
|
| 1953 |
-
self.use_fps_condition = use_fps_condition
|
| 1954 |
-
self.use_sim_mask = use_sim_mask
|
| 1955 |
-
self.training = training
|
| 1956 |
-
self.inpainting = inpainting
|
| 1957 |
-
|
| 1958 |
-
use_linear_in_temporal = False
|
| 1959 |
-
transformer_depth = 1
|
| 1960 |
-
disabled_sa = False
|
| 1961 |
-
# params
|
| 1962 |
-
enc_dims = [dim * u for u in [1] + dim_mult]
|
| 1963 |
-
dec_dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
|
| 1964 |
-
shortcut_dims = []
|
| 1965 |
-
scale = 1.0
|
| 1966 |
-
|
| 1967 |
-
# CaptionEmbedder (new add)
|
| 1968 |
-
# approx_gelu = lambda: nn.GELU(approximate="tanh")
|
| 1969 |
-
# self.y_embedder = CaptionEmbedder(
|
| 1970 |
-
# in_channels=4096,
|
| 1971 |
-
# hidden_size=1024,
|
| 1972 |
-
# uncond_prob=0.1,
|
| 1973 |
-
# act_layer=approx_gelu,
|
| 1974 |
-
# token_num=120,
|
| 1975 |
-
# )
|
| 1976 |
-
|
| 1977 |
-
# embeddings
|
| 1978 |
-
self.time_embed = nn.Sequential(
|
| 1979 |
-
nn.Linear(dim, embed_dim), nn.SiLU(),
|
| 1980 |
-
nn.Linear(embed_dim, embed_dim))
|
| 1981 |
-
|
| 1982 |
-
# self.hint_time_zero_linear = zero_module(nn.Linear(embed_dim, embed_dim))
|
| 1983 |
-
|
| 1984 |
-
# scale prompt
|
| 1985 |
-
# self.scale_cond = nn.Sequential(
|
| 1986 |
-
# nn.Linear(dim, embed_dim), nn.SiLU(),
|
| 1987 |
-
# zero_module(nn.Linear(embed_dim, embed_dim)))
|
| 1988 |
-
|
| 1989 |
-
if self.use_fps_condition:
|
| 1990 |
-
self.fps_embedding = nn.Sequential(
|
| 1991 |
-
nn.Linear(dim, embed_dim), nn.SiLU(),
|
| 1992 |
-
nn.Linear(embed_dim, embed_dim))
|
| 1993 |
-
nn.init.zeros_(self.fps_embedding[-1].weight)
|
| 1994 |
-
nn.init.zeros_(self.fps_embedding[-1].bias)
|
| 1995 |
-
|
| 1996 |
-
# encoder
|
| 1997 |
-
self.input_blocks = nn.ModuleList()
|
| 1998 |
-
init_block = nn.ModuleList([nn.Conv2d(self.in_dim, dim, 3, padding=1)])
|
| 1999 |
-
# need an initial temporal attention?
|
| 2000 |
-
if temporal_attention:
|
| 2001 |
-
if USE_TEMPORAL_TRANSFORMER:
|
| 2002 |
-
init_block.append(
|
| 2003 |
-
TemporalTransformer(
|
| 2004 |
-
dim,
|
| 2005 |
-
num_heads,
|
| 2006 |
-
head_dim,
|
| 2007 |
-
depth=transformer_depth,
|
| 2008 |
-
context_dim=context_dim,
|
| 2009 |
-
disable_self_attn=disabled_sa,
|
| 2010 |
-
use_linear=use_linear_in_temporal,
|
| 2011 |
-
multiply_zero=use_image_dataset,
|
| 2012 |
-
is_ctrl=True,))
|
| 2013 |
-
else:
|
| 2014 |
-
init_block.append(
|
| 2015 |
-
TemporalAttentionMultiBlock(
|
| 2016 |
-
dim,
|
| 2017 |
-
num_heads,
|
| 2018 |
-
head_dim,
|
| 2019 |
-
rotary_emb=self.rotary_emb,
|
| 2020 |
-
temporal_attn_times=temporal_attn_times,
|
| 2021 |
-
use_image_dataset=use_image_dataset))
|
| 2022 |
-
self.input_blocks.append(init_block)
|
| 2023 |
-
self.zero_convs = nn.ModuleList([self.make_zero_conv(dim)])
|
| 2024 |
-
shortcut_dims.append(dim)
|
| 2025 |
-
for i, (in_dim,
|
| 2026 |
-
out_dim) in enumerate(zip(enc_dims[:-1], enc_dims[1:])):
|
| 2027 |
-
for j in range(num_res_blocks):
|
| 2028 |
-
block = nn.ModuleList([
|
| 2029 |
-
ResBlock(
|
| 2030 |
-
in_dim,
|
| 2031 |
-
embed_dim,
|
| 2032 |
-
dropout,
|
| 2033 |
-
out_channels=out_dim,
|
| 2034 |
-
use_scale_shift_norm=False,
|
| 2035 |
-
use_image_dataset=use_image_dataset,
|
| 2036 |
-
)
|
| 2037 |
-
])
|
| 2038 |
-
if scale in attn_scales:
|
| 2039 |
-
block.append(
|
| 2040 |
-
SpatialTransformer(
|
| 2041 |
-
out_dim,
|
| 2042 |
-
out_dim // head_dim,
|
| 2043 |
-
head_dim,
|
| 2044 |
-
depth=1,
|
| 2045 |
-
context_dim=self.context_dim,
|
| 2046 |
-
disable_self_attn=False,
|
| 2047 |
-
use_linear=True,
|
| 2048 |
-
is_ctrl=True))
|
| 2049 |
-
if self.temporal_attention:
|
| 2050 |
-
if USE_TEMPORAL_TRANSFORMER:
|
| 2051 |
-
block.append(
|
| 2052 |
-
TemporalTransformer(
|
| 2053 |
-
out_dim,
|
| 2054 |
-
out_dim // head_dim,
|
| 2055 |
-
head_dim,
|
| 2056 |
-
depth=transformer_depth,
|
| 2057 |
-
context_dim=context_dim,
|
| 2058 |
-
disable_self_attn=disabled_sa,
|
| 2059 |
-
use_linear=use_linear_in_temporal,
|
| 2060 |
-
multiply_zero=use_image_dataset,
|
| 2061 |
-
is_ctrl=True,))
|
| 2062 |
-
else:
|
| 2063 |
-
block.append(
|
| 2064 |
-
TemporalAttentionMultiBlock(
|
| 2065 |
-
out_dim,
|
| 2066 |
-
num_heads,
|
| 2067 |
-
head_dim,
|
| 2068 |
-
rotary_emb=self.rotary_emb,
|
| 2069 |
-
use_image_dataset=use_image_dataset,
|
| 2070 |
-
use_sim_mask=use_sim_mask,
|
| 2071 |
-
temporal_attn_times=temporal_attn_times))
|
| 2072 |
-
in_dim = out_dim
|
| 2073 |
-
self.input_blocks.append(block)
|
| 2074 |
-
self.zero_convs.append(self.make_zero_conv(out_dim))
|
| 2075 |
-
shortcut_dims.append(out_dim)
|
| 2076 |
-
|
| 2077 |
-
# downsample
|
| 2078 |
-
if i != len(dim_mult) - 1 and j == num_res_blocks - 1:
|
| 2079 |
-
downsample = Downsample(
|
| 2080 |
-
out_dim, True, dims=2, out_channels=out_dim)
|
| 2081 |
-
shortcut_dims.append(out_dim)
|
| 2082 |
-
scale /= 2.0
|
| 2083 |
-
self.input_blocks.append(downsample)
|
| 2084 |
-
self.zero_convs.append(self.make_zero_conv(out_dim))
|
| 2085 |
-
|
| 2086 |
-
self.middle_block = nn.ModuleList([
|
| 2087 |
-
ResBlock(
|
| 2088 |
-
out_dim,
|
| 2089 |
-
embed_dim,
|
| 2090 |
-
dropout,
|
| 2091 |
-
use_scale_shift_norm=False,
|
| 2092 |
-
use_image_dataset=use_image_dataset,
|
| 2093 |
-
),
|
| 2094 |
-
SpatialTransformer(
|
| 2095 |
-
out_dim,
|
| 2096 |
-
out_dim // head_dim,
|
| 2097 |
-
head_dim,
|
| 2098 |
-
depth=1,
|
| 2099 |
-
context_dim=self.context_dim,
|
| 2100 |
-
disable_self_attn=False,
|
| 2101 |
-
use_linear=True,
|
| 2102 |
-
is_ctrl=True)
|
| 2103 |
-
])
|
| 2104 |
-
|
| 2105 |
-
if self.temporal_attention:
|
| 2106 |
-
if USE_TEMPORAL_TRANSFORMER:
|
| 2107 |
-
self.middle_block.append(
|
| 2108 |
-
TemporalTransformer(
|
| 2109 |
-
out_dim,
|
| 2110 |
-
out_dim // head_dim,
|
| 2111 |
-
head_dim,
|
| 2112 |
-
depth=transformer_depth,
|
| 2113 |
-
context_dim=context_dim,
|
| 2114 |
-
disable_self_attn=disabled_sa,
|
| 2115 |
-
use_linear=use_linear_in_temporal,
|
| 2116 |
-
multiply_zero=use_image_dataset,
|
| 2117 |
-
is_ctrl=True,
|
| 2118 |
-
))
|
| 2119 |
-
else:
|
| 2120 |
-
self.middle_block.append(
|
| 2121 |
-
TemporalAttentionMultiBlock(
|
| 2122 |
-
out_dim,
|
| 2123 |
-
num_heads,
|
| 2124 |
-
head_dim,
|
| 2125 |
-
rotary_emb=self.rotary_emb,
|
| 2126 |
-
use_image_dataset=use_image_dataset,
|
| 2127 |
-
use_sim_mask=use_sim_mask,
|
| 2128 |
-
temporal_attn_times=temporal_attn_times))
|
| 2129 |
-
|
| 2130 |
-
self.middle_block.append(
|
| 2131 |
-
ResBlock(out_dim, embed_dim, dropout, use_scale_shift_norm=False))
|
| 2132 |
-
|
| 2133 |
-
self.middle_block_out = self.make_zero_conv(embed_dim)
|
| 2134 |
-
|
| 2135 |
-
'''
|
| 2136 |
-
add prompt
|
| 2137 |
-
'''
|
| 2138 |
-
add_dim = 320
|
| 2139 |
-
self.add_dim = add_dim
|
| 2140 |
-
|
| 2141 |
-
self.input_hint_block = zero_module(nn.Conv2d(4, add_dim, 3, padding=1))
|
| 2142 |
-
|
| 2143 |
-
def make_zero_conv(self, in_channels, out_channels=None):
|
| 2144 |
-
out_channels = in_channels if out_channels is None else out_channels
|
| 2145 |
-
return TimestepEmbedSequential(zero_module(nn.Conv2d(in_channels, out_channels, 1, padding=0)))
|
| 2146 |
-
|
| 2147 |
-
def forward(self,
|
| 2148 |
-
x,
|
| 2149 |
-
t,
|
| 2150 |
-
y,
|
| 2151 |
-
s_cond=None,
|
| 2152 |
-
hint=None,
|
| 2153 |
-
variant_info=None,
|
| 2154 |
-
t_hint=None,
|
| 2155 |
-
mask_cond=None,
|
| 2156 |
-
fps=None,
|
| 2157 |
-
video_mask=None,
|
| 2158 |
-
focus_present_mask=None,
|
| 2159 |
-
prob_focus_present=0.,
|
| 2160 |
-
mask_last_frame_num=0):
|
| 2161 |
-
|
| 2162 |
-
batch, _, f, _, _ = x.shape
|
| 2163 |
-
device = x.device
|
| 2164 |
-
self.batch = batch
|
| 2165 |
-
|
| 2166 |
-
# image and video joint training, if mask_last_frame_num is set, prob_focus_present will be ignored
|
| 2167 |
-
if mask_last_frame_num > 0:
|
| 2168 |
-
focus_present_mask = None
|
| 2169 |
-
video_mask[-mask_last_frame_num:] = False
|
| 2170 |
-
else:
|
| 2171 |
-
focus_present_mask = default(
|
| 2172 |
-
focus_present_mask, lambda: prob_mask_like(
|
| 2173 |
-
(batch, ), prob_focus_present, device=device))
|
| 2174 |
-
|
| 2175 |
-
if self.temporal_attention and not USE_TEMPORAL_TRANSFORMER:
|
| 2176 |
-
time_rel_pos_bias = self.time_rel_pos_bias(
|
| 2177 |
-
x.shape[2], device=x.device)
|
| 2178 |
-
else:
|
| 2179 |
-
time_rel_pos_bias = None
|
| 2180 |
-
|
| 2181 |
-
if hint is not None:
|
| 2182 |
-
# add = x.new_zeros(batch, self.add_dim, f, h, w)
|
| 2183 |
-
hint = rearrange(hint, 'b c f h w -> (b f) c h w')
|
| 2184 |
-
hint = self.input_hint_block(hint)
|
| 2185 |
-
# hint = rearrange(hint, '(b f) c h w -> b c f h w', b = batch)
|
| 2186 |
-
|
| 2187 |
-
e = self.time_embed(sinusoidal_embedding(t, self.dim))
|
| 2188 |
-
e = e.repeat_interleave(repeats=f, dim=0)
|
| 2189 |
-
|
| 2190 |
-
context = y.repeat_interleave(repeats=f, dim=0)
|
| 2191 |
-
|
| 2192 |
-
# always in shape (b f) c h w, except for temporal layer
|
| 2193 |
-
x = rearrange(x, 'b c f h w -> (b f) c h w')
|
| 2194 |
-
# print('before x shape:', x.shape) [64, 320, 90, 160]
|
| 2195 |
-
# print('hint shape:', hint.shape) [32, 320, 90, 160]
|
| 2196 |
-
|
| 2197 |
-
# encoder
|
| 2198 |
-
xs = []
|
| 2199 |
-
for module, zero_conv in zip(self.input_blocks, self.zero_convs):
|
| 2200 |
-
if hint is not None:
|
| 2201 |
-
for block in module:
|
| 2202 |
-
x = self._forward_single(block, x, e, context, time_rel_pos_bias,
|
| 2203 |
-
focus_present_mask, video_mask, variant_info=variant_info)
|
| 2204 |
-
if not isinstance(block, TemporalTransformer):
|
| 2205 |
-
if hint is not None:
|
| 2206 |
-
x += hint
|
| 2207 |
-
hint = None
|
| 2208 |
-
else:
|
| 2209 |
-
x = self._forward_single(module, x, e, context, time_rel_pos_bias,
|
| 2210 |
-
focus_present_mask, video_mask, variant_info=variant_info)
|
| 2211 |
-
xs.append(zero_conv(x, e, context))
|
| 2212 |
-
|
| 2213 |
-
# middle
|
| 2214 |
-
for block in self.middle_block:
|
| 2215 |
-
x = self._forward_single(block, x, e, context, time_rel_pos_bias,
|
| 2216 |
-
focus_present_mask, video_mask, variant_info=variant_info)
|
| 2217 |
-
xs.append(self.middle_block_out(x, e, context))
|
| 2218 |
-
|
| 2219 |
-
return xs
|
| 2220 |
-
|
| 2221 |
-
def _forward_single(self,
|
| 2222 |
-
module,
|
| 2223 |
-
x,
|
| 2224 |
-
e,
|
| 2225 |
-
context,
|
| 2226 |
-
time_rel_pos_bias,
|
| 2227 |
-
focus_present_mask,
|
| 2228 |
-
video_mask,
|
| 2229 |
-
reference=None,
|
| 2230 |
-
variant_info=None,):
|
| 2231 |
-
# variant_info = None # For Debug
|
| 2232 |
-
if isinstance(module, ResidualBlock):
|
| 2233 |
-
module = checkpoint_wrapper(
|
| 2234 |
-
module) if self.use_checkpoint else module
|
| 2235 |
-
x = x.contiguous()
|
| 2236 |
-
x = module(x, e, reference)
|
| 2237 |
-
elif isinstance(module, ResBlock):
|
| 2238 |
-
module = checkpoint_wrapper(
|
| 2239 |
-
module) if self.use_checkpoint else module
|
| 2240 |
-
x = x.contiguous()
|
| 2241 |
-
x = module(x, e, self.batch, variant_info)
|
| 2242 |
-
elif isinstance(module, SpatialTransformer):
|
| 2243 |
-
module = checkpoint_wrapper(
|
| 2244 |
-
module) if self.use_checkpoint else module
|
| 2245 |
-
x = module(x, context)
|
| 2246 |
-
elif isinstance(module, TemporalTransformer):
|
| 2247 |
-
module = checkpoint_wrapper(
|
| 2248 |
-
module) if self.use_checkpoint else module
|
| 2249 |
-
x = rearrange(x, '(b f) c h w -> b c f h w', b=self.batch)
|
| 2250 |
-
# print("x shape:", x.shape) # [2, 320, 32, 90, 160]
|
| 2251 |
-
x = module(x, context)
|
| 2252 |
-
x = rearrange(x, 'b c f h w -> (b f) c h w')
|
| 2253 |
-
elif isinstance(module, CrossAttention):
|
| 2254 |
-
module = checkpoint_wrapper(
|
| 2255 |
-
module) if self.use_checkpoint else module
|
| 2256 |
-
x = module(x, context)
|
| 2257 |
-
elif isinstance(module, MemoryEfficientCrossAttention):
|
| 2258 |
-
module = checkpoint_wrapper(
|
| 2259 |
-
module) if self.use_checkpoint else module
|
| 2260 |
-
x = module(x, context)
|
| 2261 |
-
elif isinstance(module, BasicTransformerBlock):
|
| 2262 |
-
module = checkpoint_wrapper(
|
| 2263 |
-
module) if self.use_checkpoint else module
|
| 2264 |
-
x = module(x, context)
|
| 2265 |
-
elif isinstance(module, FeedForward):
|
| 2266 |
-
x = module(x, context)
|
| 2267 |
-
elif isinstance(module, Upsample):
|
| 2268 |
-
x = module(x)
|
| 2269 |
-
elif isinstance(module, Downsample):
|
| 2270 |
-
x = module(x)
|
| 2271 |
-
elif isinstance(module, Resample):
|
| 2272 |
-
x = module(x, reference)
|
| 2273 |
-
elif isinstance(module, TemporalAttentionBlock):
|
| 2274 |
-
module = checkpoint_wrapper(
|
| 2275 |
-
module) if self.use_checkpoint else module
|
| 2276 |
-
x = rearrange(x, '(b f) c h w -> b c f h w', b=self.batch)
|
| 2277 |
-
x = module(x, time_rel_pos_bias, focus_present_mask, video_mask)
|
| 2278 |
-
x = rearrange(x, 'b c f h w -> (b f) c h w')
|
| 2279 |
-
elif isinstance(module, TemporalAttentionMultiBlock):
|
| 2280 |
-
module = checkpoint_wrapper(
|
| 2281 |
-
module) if self.use_checkpoint else module
|
| 2282 |
-
x = rearrange(x, '(b f) c h w -> b c f h w', b=self.batch)
|
| 2283 |
-
x = module(x, time_rel_pos_bias, focus_present_mask, video_mask)
|
| 2284 |
-
x = rearrange(x, 'b c f h w -> (b f) c h w')
|
| 2285 |
-
elif isinstance(module, InitTemporalConvBlock):
|
| 2286 |
-
module = checkpoint_wrapper(
|
| 2287 |
-
module) if self.use_checkpoint else module
|
| 2288 |
-
x = rearrange(x, '(b f) c h w -> b c f h w', b=self.batch)
|
| 2289 |
-
x = module(x)
|
| 2290 |
-
x = rearrange(x, 'b c f h w -> (b f) c h w')
|
| 2291 |
-
elif isinstance(module, TemporalConvBlock):
|
| 2292 |
-
module = checkpoint_wrapper(
|
| 2293 |
-
module) if self.use_checkpoint else module
|
| 2294 |
-
x = rearrange(x, '(b f) c h w -> b c f h w', b=self.batch)
|
| 2295 |
-
x = module(x)
|
| 2296 |
-
x = rearrange(x, 'b c f h w -> (b f) c h w')
|
| 2297 |
-
elif isinstance(module, nn.ModuleList):
|
| 2298 |
-
for block in module:
|
| 2299 |
-
x = self._forward_single(block, x, e, context,
|
| 2300 |
-
time_rel_pos_bias, focus_present_mask,
|
| 2301 |
-
video_mask, reference, variant_info)
|
| 2302 |
-
else:
|
| 2303 |
-
x = module(x)
|
| 2304 |
-
return x
|
| 2305 |
-
|
| 2306 |
-
|
| 2307 |
-
class TimestepBlock(nn.Module):
|
| 2308 |
-
"""
|
| 2309 |
-
Any module where forward() takes timestep embeddings as a second argument.
|
| 2310 |
-
"""
|
| 2311 |
-
|
| 2312 |
-
@abstractmethod
|
| 2313 |
-
def forward(self, x, emb):
|
| 2314 |
-
"""
|
| 2315 |
-
Apply the module to `x` given `emb` timestep embeddings.
|
| 2316 |
-
"""
|
| 2317 |
-
|
| 2318 |
-
class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
|
| 2319 |
-
"""
|
| 2320 |
-
A sequential module that passes timestep embeddings to the children that
|
| 2321 |
-
support it as an extra input.
|
| 2322 |
-
"""
|
| 2323 |
-
|
| 2324 |
-
def forward(self, x, emb, context=None):
|
| 2325 |
-
for layer in self:
|
| 2326 |
-
if isinstance(layer, TimestepBlock):
|
| 2327 |
-
x = layer(x, emb)
|
| 2328 |
-
elif isinstance(layer, SpatialTransformer):
|
| 2329 |
-
x = layer(x, context)
|
| 2330 |
-
else:
|
| 2331 |
-
x = layer(x)
|
| 2332 |
-
return x
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
video_to_video/utils/__init__.py
DELETED
|
File without changes
|
video_to_video/utils/__pycache__/__init__.cpython-39.pyc
DELETED
|
Binary file (159 Bytes)
|
|
|
video_to_video/utils/__pycache__/config.cpython-39.pyc
DELETED
|
Binary file (3.44 kB)
|
|
|
video_to_video/utils/__pycache__/logger.cpython-39.pyc
DELETED
|
Binary file (2.14 kB)
|
|
|
video_to_video/utils/__pycache__/seed.cpython-39.pyc
DELETED
|
Binary file (467 Bytes)
|
|
|
video_to_video/utils/config.py
DELETED
|
@@ -1,169 +0,0 @@
|
|
| 1 |
-
# Copyright (c) Alibaba, Inc. and its affiliates.
|
| 2 |
-
|
| 3 |
-
import logging
|
| 4 |
-
import os
|
| 5 |
-
import os.path as osp
|
| 6 |
-
from datetime import datetime
|
| 7 |
-
|
| 8 |
-
import torch
|
| 9 |
-
from easydict import EasyDict
|
| 10 |
-
|
| 11 |
-
cfg = EasyDict(__name__='Config: VideoLDM Decoder')
|
| 12 |
-
|
| 13 |
-
# ---------------------------work dir--------------------------
|
| 14 |
-
cfg.work_dir = 'workspace/'
|
| 15 |
-
|
| 16 |
-
# ---------------------------Global Variable-----------------------------------
|
| 17 |
-
cfg.resolution = [448, 256]
|
| 18 |
-
cfg.max_frames = 32
|
| 19 |
-
# -----------------------------------------------------------------------------
|
| 20 |
-
|
| 21 |
-
# ---------------------------Dataset Parameter---------------------------------
|
| 22 |
-
cfg.mean = [0.5, 0.5, 0.5]
|
| 23 |
-
cfg.std = [0.5, 0.5, 0.5]
|
| 24 |
-
cfg.max_words = 1000
|
| 25 |
-
|
| 26 |
-
# PlaceHolder
|
| 27 |
-
cfg.vit_out_dim = 1024
|
| 28 |
-
cfg.vit_resolution = [224, 224]
|
| 29 |
-
cfg.depth_clamp = 10.0
|
| 30 |
-
cfg.misc_size = 384
|
| 31 |
-
cfg.depth_std = 20.0
|
| 32 |
-
|
| 33 |
-
cfg.frame_lens = 32
|
| 34 |
-
cfg.sample_fps = 8
|
| 35 |
-
|
| 36 |
-
cfg.batch_sizes = 1
|
| 37 |
-
# -----------------------------------------------------------------------------
|
| 38 |
-
|
| 39 |
-
# ---------------------------Mode Parameters-----------------------------------
|
| 40 |
-
# Diffusion
|
| 41 |
-
cfg.schedule = 'cosine'
|
| 42 |
-
cfg.num_timesteps = 1000
|
| 43 |
-
cfg.mean_type = 'v'
|
| 44 |
-
cfg.var_type = 'fixed_small'
|
| 45 |
-
cfg.loss_type = 'mse'
|
| 46 |
-
cfg.ddim_timesteps = 50
|
| 47 |
-
cfg.ddim_eta = 0.0
|
| 48 |
-
cfg.clamp = 1.0
|
| 49 |
-
cfg.share_noise = False
|
| 50 |
-
cfg.use_div_loss = False
|
| 51 |
-
cfg.noise_strength = 0.1
|
| 52 |
-
|
| 53 |
-
# classifier-free guidance
|
| 54 |
-
cfg.p_zero = 0.1
|
| 55 |
-
cfg.guide_scale = 3.0
|
| 56 |
-
|
| 57 |
-
# clip vision encoder
|
| 58 |
-
cfg.vit_mean = [0.48145466, 0.4578275, 0.40821073]
|
| 59 |
-
cfg.vit_std = [0.26862954, 0.26130258, 0.27577711]
|
| 60 |
-
|
| 61 |
-
# Model
|
| 62 |
-
cfg.scale_factor = 0.18215
|
| 63 |
-
cfg.use_fp16 = True
|
| 64 |
-
cfg.temporal_attention = True
|
| 65 |
-
cfg.decoder_bs = 8
|
| 66 |
-
|
| 67 |
-
cfg.UNet = {
|
| 68 |
-
'type': 'Vid2VidSDUNet',
|
| 69 |
-
'in_dim': 4,
|
| 70 |
-
'dim': 320,
|
| 71 |
-
'y_dim': cfg.vit_out_dim,
|
| 72 |
-
'context_dim': 1024,
|
| 73 |
-
'out_dim': 8 if cfg.var_type.startswith('learned') else 4,
|
| 74 |
-
'dim_mult': [1, 2, 4, 4],
|
| 75 |
-
'num_heads': 8,
|
| 76 |
-
'head_dim': 64,
|
| 77 |
-
'num_res_blocks': 2,
|
| 78 |
-
'attn_scales': [1 / 1, 1 / 2, 1 / 4],
|
| 79 |
-
'dropout': 0.1,
|
| 80 |
-
'temporal_attention': cfg.temporal_attention,
|
| 81 |
-
'temporal_attn_times': 1,
|
| 82 |
-
'use_checkpoint': False,
|
| 83 |
-
'use_fps_condition': False,
|
| 84 |
-
'use_sim_mask': False,
|
| 85 |
-
'num_tokens': 4,
|
| 86 |
-
'default_fps': 8,
|
| 87 |
-
'input_dim': 1024
|
| 88 |
-
}
|
| 89 |
-
|
| 90 |
-
cfg.guidances = []
|
| 91 |
-
|
| 92 |
-
# auotoencoder from stabel diffusion
|
| 93 |
-
cfg.auto_encoder = {
|
| 94 |
-
'type': 'AutoencoderKL',
|
| 95 |
-
'ddconfig': {
|
| 96 |
-
'double_z': True,
|
| 97 |
-
'z_channels': 4,
|
| 98 |
-
'resolution': 256,
|
| 99 |
-
'in_channels': 3,
|
| 100 |
-
'out_ch': 3,
|
| 101 |
-
'ch': 128,
|
| 102 |
-
'ch_mult': [1, 2, 4, 4],
|
| 103 |
-
'num_res_blocks': 2,
|
| 104 |
-
'attn_resolutions': [],
|
| 105 |
-
'dropout': 0.0
|
| 106 |
-
},
|
| 107 |
-
'embed_dim': 4,
|
| 108 |
-
'pretrained': 'models/v2-1_512-ema-pruned.ckpt'
|
| 109 |
-
}
|
| 110 |
-
# clip embedder
|
| 111 |
-
cfg.embedder = {
|
| 112 |
-
'type': 'FrozenOpenCLIPEmbedder',
|
| 113 |
-
'layer': 'penultimate',
|
| 114 |
-
'vit_resolution': [224, 224],
|
| 115 |
-
'pretrained': 'open_clip_pytorch_model.bin'
|
| 116 |
-
}
|
| 117 |
-
# -----------------------------------------------------------------------------
|
| 118 |
-
|
| 119 |
-
# ---------------------------Training Settings---------------------------------
|
| 120 |
-
# training and optimizer
|
| 121 |
-
cfg.ema_decay = 0.9999
|
| 122 |
-
cfg.num_steps = 600000
|
| 123 |
-
cfg.lr = 5e-5
|
| 124 |
-
cfg.weight_decay = 0.0
|
| 125 |
-
cfg.betas = (0.9, 0.999)
|
| 126 |
-
cfg.eps = 1.0e-8
|
| 127 |
-
cfg.chunk_size = 16
|
| 128 |
-
cfg.alpha = 0.7
|
| 129 |
-
cfg.save_ckp_interval = 1000
|
| 130 |
-
# -----------------------------------------------------------------------------
|
| 131 |
-
|
| 132 |
-
# ----------------------------Pretrain Settings---------------------------------
|
| 133 |
-
# Default: load 2d pretrain
|
| 134 |
-
cfg.fix_weight = False
|
| 135 |
-
cfg.load_match = False
|
| 136 |
-
cfg.pretrained_checkpoint = 'v2-1_512-ema-pruned.ckpt'
|
| 137 |
-
cfg.pretrained_image_keys = 'stable_diffusion_image_key_temporal_attention_x1.json'
|
| 138 |
-
cfg.resume_checkpoint = 'img2video_ldm_0779000.pth'
|
| 139 |
-
# -----------------------------------------------------------------------------
|
| 140 |
-
|
| 141 |
-
# -----------------------------Visual-------------------------------------------
|
| 142 |
-
# Visual videos
|
| 143 |
-
cfg.viz_interval = 1000
|
| 144 |
-
cfg.visual_train = {
|
| 145 |
-
'type': 'VisualVideoTextDuringTrain',
|
| 146 |
-
}
|
| 147 |
-
cfg.visual_inference = {
|
| 148 |
-
'type': 'VisualGeneratedVideos',
|
| 149 |
-
}
|
| 150 |
-
cfg.inference_list_path = ''
|
| 151 |
-
|
| 152 |
-
# logging
|
| 153 |
-
cfg.log_interval = 100
|
| 154 |
-
|
| 155 |
-
# Default log_dir
|
| 156 |
-
cfg.log_dir = 'workspace/output_data'
|
| 157 |
-
# -----------------------------------------------------------------------------
|
| 158 |
-
|
| 159 |
-
# ---------------------------Others--------------------------------------------
|
| 160 |
-
# seed
|
| 161 |
-
cfg.seed = 8888
|
| 162 |
-
|
| 163 |
-
cfg.negative_prompt = 'painting, oil painting, illustration, drawing, art, sketch, oil painting, cartoon, \
|
| 164 |
-
CG Style, 3D render, unreal engine, blurring, dirty, messy, worst quality, low quality, frames, watermark, \
|
| 165 |
-
signature, jpeg artifacts, deformed, lowres, over-smooth'
|
| 166 |
-
|
| 167 |
-
cfg.positive_prompt = 'Cinematic, High Contrast, highly detailed, taken using a Canon EOS R camera, \
|
| 168 |
-
hyper detailed photo - realistic maximum detail, 32k, Color Grading, ultra HD, extreme meticulous detailing, \
|
| 169 |
-
skin pore detailing, hyper sharpness, perfect without deformations.'
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
video_to_video/utils/logger.py
DELETED
|
@@ -1,94 +0,0 @@
|
|
| 1 |
-
# Copyright (c) Alibaba, Inc. and its affiliates.
|
| 2 |
-
|
| 3 |
-
import importlib
|
| 4 |
-
import logging
|
| 5 |
-
from typing import Optional
|
| 6 |
-
from torch import distributed as dist
|
| 7 |
-
|
| 8 |
-
init_loggers = {}
|
| 9 |
-
|
| 10 |
-
formatter = logging.Formatter(
|
| 11 |
-
'%(asctime)s - %(name)s - %(levelname)s - %(message)s')
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
def get_logger(log_file: Optional[str] = None,
|
| 15 |
-
log_level: int = logging.INFO,
|
| 16 |
-
file_mode: str = 'w'):
|
| 17 |
-
""" Get logging logger
|
| 18 |
-
|
| 19 |
-
Args:
|
| 20 |
-
log_file: Log filename, if specified, file handler will be added to
|
| 21 |
-
logger
|
| 22 |
-
log_level: Logging level.
|
| 23 |
-
file_mode: Specifies the mode to open the file, if filename is
|
| 24 |
-
specified (if filemode is unspecified, it defaults to 'w').
|
| 25 |
-
"""
|
| 26 |
-
|
| 27 |
-
logger_name = __name__.split('.')[0]
|
| 28 |
-
logger = logging.getLogger(logger_name)
|
| 29 |
-
logger.propagate = False
|
| 30 |
-
if logger_name in init_loggers:
|
| 31 |
-
add_file_handler_if_needed(logger, log_file, file_mode, log_level)
|
| 32 |
-
return logger
|
| 33 |
-
|
| 34 |
-
# handle duplicate logs to the console
|
| 35 |
-
# Starting in 1.8.0, PyTorch DDP attaches a StreamHandler <stderr> (NOTSET)
|
| 36 |
-
# to the root logger. As logger.propagate is True by default, this root
|
| 37 |
-
# level handler causes logging messages from rank>0 processes to
|
| 38 |
-
# unexpectedly show up on the console, creating much unwanted clutter.
|
| 39 |
-
# To fix this issue, we set the root logger's StreamHandler, if any, to log
|
| 40 |
-
# at the ERROR level.
|
| 41 |
-
for handler in logger.root.handlers:
|
| 42 |
-
if type(handler) is logging.StreamHandler:
|
| 43 |
-
handler.setLevel(logging.ERROR)
|
| 44 |
-
|
| 45 |
-
stream_handler = logging.StreamHandler()
|
| 46 |
-
handlers = [stream_handler]
|
| 47 |
-
|
| 48 |
-
if importlib.util.find_spec('torch') is not None:
|
| 49 |
-
is_worker0 = is_master()
|
| 50 |
-
else:
|
| 51 |
-
is_worker0 = True
|
| 52 |
-
|
| 53 |
-
if is_worker0 and log_file is not None:
|
| 54 |
-
file_handler = logging.FileHandler(log_file, file_mode)
|
| 55 |
-
handlers.append(file_handler)
|
| 56 |
-
|
| 57 |
-
for handler in handlers:
|
| 58 |
-
handler.setFormatter(formatter)
|
| 59 |
-
handler.setLevel(log_level)
|
| 60 |
-
logger.addHandler(handler)
|
| 61 |
-
|
| 62 |
-
if is_worker0:
|
| 63 |
-
logger.setLevel(log_level)
|
| 64 |
-
else:
|
| 65 |
-
logger.setLevel(logging.ERROR)
|
| 66 |
-
|
| 67 |
-
init_loggers[logger_name] = True
|
| 68 |
-
|
| 69 |
-
return logger
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
def add_file_handler_if_needed(logger, log_file, file_mode, log_level):
|
| 73 |
-
for handler in logger.handlers:
|
| 74 |
-
if isinstance(handler, logging.FileHandler):
|
| 75 |
-
return
|
| 76 |
-
|
| 77 |
-
if importlib.util.find_spec('torch') is not None:
|
| 78 |
-
is_worker0 = is_master()
|
| 79 |
-
else:
|
| 80 |
-
is_worker0 = True
|
| 81 |
-
|
| 82 |
-
if is_worker0 and log_file is not None:
|
| 83 |
-
file_handler = logging.FileHandler(log_file, file_mode)
|
| 84 |
-
file_handler.setFormatter(formatter)
|
| 85 |
-
file_handler.setLevel(log_level)
|
| 86 |
-
logger.addHandler(file_handler)
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
def is_master(group=None):
|
| 90 |
-
return dist.get_rank(group) == 0 if is_dist() else True
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
def is_dist():
|
| 94 |
-
return dist.is_available() and dist.is_initialized()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
video_to_video/utils/seed.py
DELETED
|
@@ -1,14 +0,0 @@
|
|
| 1 |
-
# Copyright (c) Alibaba, Inc. and its affiliates.
|
| 2 |
-
|
| 3 |
-
import random
|
| 4 |
-
|
| 5 |
-
import numpy as np
|
| 6 |
-
import torch
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
def setup_seed(seed):
|
| 10 |
-
torch.manual_seed(seed)
|
| 11 |
-
torch.cuda.manual_seed_all(seed)
|
| 12 |
-
np.random.seed(seed)
|
| 13 |
-
random.seed(seed)
|
| 14 |
-
torch.backends.cudnn.deterministic = True
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
video_to_video/video_to_video_model.py
DELETED
|
@@ -1,237 +0,0 @@
|
|
| 1 |
-
import os
|
| 2 |
-
import os.path as osp
|
| 3 |
-
import random
|
| 4 |
-
from typing import Any, Dict
|
| 5 |
-
|
| 6 |
-
import torch
|
| 7 |
-
import torch.cuda.amp as amp
|
| 8 |
-
import torch.nn.functional as F
|
| 9 |
-
|
| 10 |
-
from video_to_video.modules import *
|
| 11 |
-
from video_to_video.utils.config import cfg
|
| 12 |
-
from video_to_video.diffusion.diffusion_sdedit import GaussianDiffusion
|
| 13 |
-
from video_to_video.diffusion.schedules_sdedit import noise_schedule
|
| 14 |
-
from video_to_video.utils.logger import get_logger
|
| 15 |
-
|
| 16 |
-
from diffusers import AutoencoderKLTemporalDecoder
|
| 17 |
-
import requests
|
| 18 |
-
|
| 19 |
-
def download_model(url, model_path):
|
| 20 |
-
if not os.path.exists(os.path.join(model_path, 'heavy_deg.pt')):
|
| 21 |
-
print(f"Model not found at {model_path}, downloading...")
|
| 22 |
-
response = requests.get(url, stream=True)
|
| 23 |
-
with open(os.path.join(model_path, 'heavy_deg.pt'), 'wb') as f:
|
| 24 |
-
for chunk in response.iter_content(chunk_size=1024):
|
| 25 |
-
if chunk:
|
| 26 |
-
f.write(chunk)
|
| 27 |
-
print(f"Model downloaded to {model_path}")
|
| 28 |
-
else:
|
| 29 |
-
print(f"Model found at {model_path}, skipping download.")
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
logger = get_logger()
|
| 33 |
-
|
| 34 |
-
class VideoToVideo_sr():
|
| 35 |
-
def __init__(self, opt, device=torch.device(f'cuda:0')):
|
| 36 |
-
self.opt = opt
|
| 37 |
-
self.device = device # torch.device(f'cuda:0')
|
| 38 |
-
|
| 39 |
-
# text_encoder
|
| 40 |
-
text_encoder = FrozenOpenCLIPEmbedder(device=self.device, pretrained="laion2b_s32b_b79k")
|
| 41 |
-
text_encoder.model.to(self.device)
|
| 42 |
-
self.text_encoder = text_encoder
|
| 43 |
-
logger.info(f'Build encoder with FrozenOpenCLIPEmbedder')
|
| 44 |
-
|
| 45 |
-
# U-Net with ControlNet
|
| 46 |
-
generator = ControlledV2VUNet()
|
| 47 |
-
generator = generator.to(self.device)
|
| 48 |
-
generator.eval()
|
| 49 |
-
|
| 50 |
-
# 确保 cfg.model_path 是文件夹路径,不要加上文件名
|
| 51 |
-
cfg.model_path = opt.model_path
|
| 52 |
-
# download weight
|
| 53 |
-
model_url = 'https://huggingface.co/SherryX/STAR/resolve/main/I2VGen-XL-based/heavy_deg.pt'
|
| 54 |
-
download_model(model_url, cfg.model_path)
|
| 55 |
-
|
| 56 |
-
# 拼接完整路径
|
| 57 |
-
model_file_path = os.path.join('pretrained_weight', 'I2VGen-XL-based', 'heavy_deg.pt')
|
| 58 |
-
print('model_file_path:', model_file_path)
|
| 59 |
-
|
| 60 |
-
# 加载模型
|
| 61 |
-
load_dict = torch.load(model_file_path, map_location='cpu')
|
| 62 |
-
|
| 63 |
-
if 'state_dict' in load_dict:
|
| 64 |
-
load_dict = load_dict['state_dict']
|
| 65 |
-
ret = generator.load_state_dict(load_dict, strict=False)
|
| 66 |
-
|
| 67 |
-
self.generator = generator.half()
|
| 68 |
-
logger.info('Load model path {}, with local status {}'.format(cfg.model_path, ret))
|
| 69 |
-
|
| 70 |
-
# Noise scheduler
|
| 71 |
-
sigmas = noise_schedule(
|
| 72 |
-
schedule='logsnr_cosine_interp',
|
| 73 |
-
n=1000,
|
| 74 |
-
zero_terminal_snr=True,
|
| 75 |
-
scale_min=2.0,
|
| 76 |
-
scale_max=4.0)
|
| 77 |
-
diffusion = GaussianDiffusion(sigmas=sigmas)
|
| 78 |
-
self.diffusion = diffusion
|
| 79 |
-
logger.info('Build diffusion with GaussianDiffusion')
|
| 80 |
-
|
| 81 |
-
# Temporal VAE
|
| 82 |
-
vae = AutoencoderKLTemporalDecoder.from_pretrained(
|
| 83 |
-
"stabilityai/stable-video-diffusion-img2vid", subfolder="vae", variant="fp16"
|
| 84 |
-
)
|
| 85 |
-
vae.eval()
|
| 86 |
-
vae.requires_grad_(False)
|
| 87 |
-
vae.to(self.device)
|
| 88 |
-
self.vae = vae
|
| 89 |
-
logger.info('Build Temporal VAE')
|
| 90 |
-
|
| 91 |
-
torch.cuda.empty_cache()
|
| 92 |
-
|
| 93 |
-
self.negative_prompt = cfg.negative_prompt
|
| 94 |
-
self.positive_prompt = cfg.positive_prompt
|
| 95 |
-
|
| 96 |
-
negative_y = text_encoder(self.negative_prompt).detach()
|
| 97 |
-
self.negative_y = negative_y
|
| 98 |
-
|
| 99 |
-
self.chunk_size = opt.chunk_size
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
def test(self, input: Dict[str, Any], total_noise_levels=1000, \
|
| 103 |
-
steps=50, solver_mode='fast', guide_scale=7.5, max_chunk_len=32):
|
| 104 |
-
video_data = input['video_data']
|
| 105 |
-
y = input['y']
|
| 106 |
-
(target_h, target_w) = input['target_res']
|
| 107 |
-
|
| 108 |
-
video_data = F.interpolate(video_data, [target_h,target_w], mode='bilinear')
|
| 109 |
-
|
| 110 |
-
logger.info(f'video_data shape: {video_data.shape}')
|
| 111 |
-
frames_num, _, h, w = video_data.shape
|
| 112 |
-
|
| 113 |
-
padding = pad_to_fit(h, w)
|
| 114 |
-
video_data = F.pad(video_data, padding, 'constant', 1)
|
| 115 |
-
|
| 116 |
-
video_data = video_data.unsqueeze(0)
|
| 117 |
-
bs = 1
|
| 118 |
-
video_data = video_data.to(self.device)
|
| 119 |
-
|
| 120 |
-
video_data_feature = self.vae_encode(video_data)
|
| 121 |
-
torch.cuda.empty_cache()
|
| 122 |
-
|
| 123 |
-
y = self.text_encoder(y).detach()
|
| 124 |
-
|
| 125 |
-
with amp.autocast(enabled=True):
|
| 126 |
-
|
| 127 |
-
t = torch.LongTensor([total_noise_levels-1]).to(self.device)
|
| 128 |
-
noised_lr = self.diffusion.diffuse(video_data_feature, t)
|
| 129 |
-
|
| 130 |
-
model_kwargs = [{'y': y}, {'y': self.negative_y}]
|
| 131 |
-
model_kwargs.append({'hint': video_data_feature})
|
| 132 |
-
|
| 133 |
-
torch.cuda.empty_cache()
|
| 134 |
-
chunk_inds = make_chunks(frames_num, interp_f_num=0, max_chunk_len=max_chunk_len) if frames_num > max_chunk_len else None
|
| 135 |
-
|
| 136 |
-
solver = 'dpmpp_2m_sde' # 'heun' | 'dpmpp_2m_sde'
|
| 137 |
-
gen_vid = self.diffusion.sample_sr(
|
| 138 |
-
noise=noised_lr,
|
| 139 |
-
model=self.generator,
|
| 140 |
-
model_kwargs=model_kwargs,
|
| 141 |
-
guide_scale=guide_scale,
|
| 142 |
-
guide_rescale=0.2,
|
| 143 |
-
solver=solver,
|
| 144 |
-
solver_mode=solver_mode,
|
| 145 |
-
return_intermediate=None,
|
| 146 |
-
steps=steps,
|
| 147 |
-
t_max=total_noise_levels - 1,
|
| 148 |
-
t_min=0,
|
| 149 |
-
discretization='trailing',
|
| 150 |
-
chunk_inds=chunk_inds,)
|
| 151 |
-
torch.cuda.empty_cache()
|
| 152 |
-
|
| 153 |
-
logger.info(f'sampling, finished.')
|
| 154 |
-
vid_tensor_gen = self.vae_decode_chunk(gen_vid, chunk_size=self.chunk_size)
|
| 155 |
-
|
| 156 |
-
logger.info(f'temporal vae decoding, finished.')
|
| 157 |
-
|
| 158 |
-
w1, w2, h1, h2 = padding
|
| 159 |
-
vid_tensor_gen = vid_tensor_gen[:,:,h1:h+h1,w1:w+w1]
|
| 160 |
-
|
| 161 |
-
gen_video = rearrange(
|
| 162 |
-
vid_tensor_gen, '(b f) c h w -> b c f h w', b=bs)
|
| 163 |
-
|
| 164 |
-
torch.cuda.empty_cache()
|
| 165 |
-
|
| 166 |
-
return gen_video.type(torch.float32).cpu()
|
| 167 |
-
|
| 168 |
-
def temporal_vae_decode(self, z, num_f):
|
| 169 |
-
return self.vae.decode(z/self.vae.config.scaling_factor, num_frames=num_f).sample
|
| 170 |
-
|
| 171 |
-
def vae_decode_chunk(self, z, chunk_size=3):
|
| 172 |
-
z = rearrange(z, "b c f h w -> (b f) c h w")
|
| 173 |
-
video = []
|
| 174 |
-
for ind in range(0, z.shape[0], chunk_size):
|
| 175 |
-
num_f = z[ind:ind+chunk_size].shape[0]
|
| 176 |
-
video.append(self.temporal_vae_decode(z[ind:ind+chunk_size],num_f))
|
| 177 |
-
video = torch.cat(video)
|
| 178 |
-
return video
|
| 179 |
-
|
| 180 |
-
def vae_encode(self, t, chunk_size=1):
|
| 181 |
-
num_f = t.shape[1]
|
| 182 |
-
t = rearrange(t, "b f c h w -> (b f) c h w")
|
| 183 |
-
z_list = []
|
| 184 |
-
for ind in range(0,t.shape[0],chunk_size):
|
| 185 |
-
z_list.append(self.vae.encode(t[ind:ind+chunk_size]).latent_dist.sample())
|
| 186 |
-
z = torch.cat(z_list, dim=0)
|
| 187 |
-
z = rearrange(z, "(b f) c h w -> b c f h w", f=num_f)
|
| 188 |
-
return z * self.vae.config.scaling_factor
|
| 189 |
-
|
| 190 |
-
|
| 191 |
-
def pad_to_fit(h, w):
|
| 192 |
-
BEST_H, BEST_W = 720, 1280
|
| 193 |
-
|
| 194 |
-
if h < BEST_H:
|
| 195 |
-
h1, h2 = _create_pad(h, BEST_H)
|
| 196 |
-
elif h == BEST_H:
|
| 197 |
-
h1 = h2 = 0
|
| 198 |
-
else:
|
| 199 |
-
h1 = 0
|
| 200 |
-
h2 = int((h + 48) // 64 * 64) + 64 - 48 - h
|
| 201 |
-
|
| 202 |
-
if w < BEST_W:
|
| 203 |
-
w1, w2 = _create_pad(w, BEST_W)
|
| 204 |
-
elif w == BEST_W:
|
| 205 |
-
w1 = w2 = 0
|
| 206 |
-
else:
|
| 207 |
-
w1 = 0
|
| 208 |
-
w2 = int(w // 64 * 64) + 64 - w
|
| 209 |
-
return (w1, w2, h1, h2)
|
| 210 |
-
|
| 211 |
-
def _create_pad(h, max_len):
|
| 212 |
-
h1 = int((max_len - h) // 2)
|
| 213 |
-
h2 = max_len - h1 - h
|
| 214 |
-
return h1, h2
|
| 215 |
-
|
| 216 |
-
|
| 217 |
-
def make_chunks(f_num, interp_f_num, max_chunk_len, chunk_overlap_ratio=0.5):
|
| 218 |
-
MAX_CHUNK_LEN = max_chunk_len
|
| 219 |
-
MAX_O_LEN = MAX_CHUNK_LEN * chunk_overlap_ratio
|
| 220 |
-
chunk_len = int((MAX_CHUNK_LEN-1)//(1+interp_f_num)*(interp_f_num+1)+1)
|
| 221 |
-
o_len = int((MAX_O_LEN-1)//(1+interp_f_num)*(interp_f_num+1)+1)
|
| 222 |
-
chunk_inds = sliding_windows_1d(f_num, chunk_len, o_len)
|
| 223 |
-
return chunk_inds
|
| 224 |
-
|
| 225 |
-
|
| 226 |
-
def sliding_windows_1d(length, window_size, overlap_size):
|
| 227 |
-
stride = window_size - overlap_size
|
| 228 |
-
ind = 0
|
| 229 |
-
coords = []
|
| 230 |
-
while ind<length:
|
| 231 |
-
if ind+window_size*1.25>=length:
|
| 232 |
-
coords.append((ind,length))
|
| 233 |
-
break
|
| 234 |
-
else:
|
| 235 |
-
coords.append((ind,ind+window_size))
|
| 236 |
-
ind += stride
|
| 237 |
-
return coords
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|