import math

import numpy as np
import torch

try:
    from tqdm import tqdm
except ImportError:  # progress bars are cosmetic; fall back to a plain iterator
    def tqdm(iterable=None, *args, **kwargs):
        return iterable


class RectifiedFlow:
    """Rectified-flow training loss and rolling-window samplers for joint
    video/audio latents.

    Time convention used throughout: ``t = 1`` is the clean sample and
    ``t = 0`` is pure noise. A noisy input is ``t * x + (1 - t) * z0`` and the
    model is trained to predict the velocity ``x - z0``. The model receives
    time rescaled as ``t * (T - eps) + eps``.
    """

    def __init__(self, num_timesteps, warmup_timesteps=10, noise_scale=1.0,
                 init_type='gaussian', eps=1, sampling='logit', window_size=8):
        """
        Args:
            num_timesteps: Number of integration steps used to advance the
                rolling window by one frame.
            warmup_timesteps: Multiplier on ``num_timesteps``; the warm-up
                ("pre-rolling") phase runs
                ``warmup_timesteps * num_timesteps`` steps.
            noise_scale: Multiplier applied to standard-normal noise.
            init_type: Noise distribution for ``z0``; only ``'gaussian'`` is
                implemented.
            eps: A `float` number. The smallest time step to sample from.
            sampling: ``'logit'`` weights the training loss by a logit-normal
                density of ``t``; any other value uses uniform weights.
            window_size: Number of frames denoised jointly in one window.
        """
        self.num_timesteps = num_timesteps
        self.warmup_timesteps = warmup_timesteps * num_timesteps
        self.T = 1000.
        self.noise_scale = noise_scale
        self.init_type = init_type
        self.eps = eps
        self.window_size = window_size
        self.sampling = sampling

    # ------------------------------------------------------------------ #
    # Small shared helpers
    # ------------------------------------------------------------------ #
    def _window_indexes(self, batch_size, device):
        """Return a float tensor [batch_size, window_size] holding 0..window_size-1 per row."""
        return torch.linspace(0, self.window_size - 1, steps=self.window_size,
                              device=device).unsqueeze(0).repeat(batch_size, 1)

    @staticmethod
    def _expand(t, ndim=5):
        """Append trailing singleton dims to ``t`` so it broadcasts against an ``ndim``-D tensor."""
        return t.reshape(*t.shape, *([1] * (ndim - t.dim())))

    def _rolling_timestep(self, window_indexes, i):
        """Per-frame time ``t`` and model time ``t_rf`` for step ``i`` of the rolling phase."""
        tw = (self.num_timesteps - i) / self.num_timesteps
        t = (window_indexes + tw) / self.window_size
        t = 1 - t  # 1 is the clean sample, 0 is pure noise
        t_rf = t * (self.T - self.eps) + self.eps
        return t, t_rf

    def _shift_window(self, z):
        """Pop the first frame of ``z`` and append a fresh noise frame at the end.

        Returns ``(first_frame, shifted_window)``.
        """
        head = z[:, 0]
        fresh = torch.randn_like(z[:, 0]).unsqueeze(1) * self.noise_scale
        return head, torch.cat([z[:, 1:], fresh], dim=1)

    @staticmethod
    def _guidance_mask(t_partial, dt, tol=1e-5):
        """Boolean mask that is False for frames that reach ``t == 1`` after this step.

        Guidance is dropped on the step that lands a frame on the clean
        sample. A tolerance is used instead of exact float equality, because
        ``t_partial + dt`` mixes float32 tensors with Python floats and may
        round to a value just below or above 1.
        """
        return (t_partial + dt - 1).abs() > tol

    # ------------------------------------------------------------------ #
    # Loss weighting
    # ------------------------------------------------------------------ #
    def logit(self, x):
        return torch.log(x / (1 - x))

    def logit_normal(self, x, mu=0, sigma=1):
        """Logit-normal probability density evaluated at ``x`` in (0, 1)."""
        return 1 / (sigma * math.sqrt(2 * math.pi) * x * (1 - x)) * torch.exp(
            -(self.logit(x) - mu) ** 2 / (2 * sigma ** 2))

    # ------------------------------------------------------------------ #
    # Training
    # ------------------------------------------------------------------ #
    def training_loss(self, model, v, a, model_kwargs):
        """Weighted flow-matching loss for one batch.

        Args:
            model: Callable ``model(video, audio, t_rf, **model_kwargs)``
                returning ``(score_v, score_a)`` shaped like its inputs.
            v: Video latents ``[B, T, C, H, W]`` with ``T == window_size``.
            a: Audio latents ``[B, T, ...]`` of rank >= 3 (the sampling
                methods below assume 5-D audio, like ``v``).
            model_kwargs: Extra keyword arguments forwarded to ``model``.

        Returns:
            ``{"loss": scalar tensor}``: the sum of the video and audio MSE
            losses. Each is averaged over feature dims, multiplied by the
            per-frame weight, then averaged over batch and frames.
        """
        B = v.shape[0]
        device = v.device
        tw = torch.rand((B, 1), device=device)
        window_indexes = self._window_indexes(B, device)
        # With probability 0.8 use the rolling schedule (a window spans one
        # sub-interval of width 1/window_size); otherwise use the warm-up
        # schedule, whose times may leave [0, 1] and are clamped below.
        rollout = torch.bernoulli(torch.tensor(0.8).repeat(B).to(device)).bool()
        t_rollout = (window_indexes + tw) / self.window_size
        t_pre_rollout = window_indexes / self.window_size + tw
        t = torch.where(rollout.unsqueeze(1), t_rollout, t_pre_rollout)
        t = 1 - t  # swap 0 and 1, since 1 is full image and 0 is full noise
        t = torch.clamp(t, 0 + 1e-6, 1 - 1e-6)  # keep logit() finite

        if self.sampling == 'logit':
            weights = self.logit_normal(t, mu=0, sigma=1)
        else:
            weights = torch.ones_like(t)

        v_z0 = self.get_z0(v).to(v.device)
        a_z0 = self.get_z0(a).to(a.device)
        t_video = self._expand(t, v.dim())
        t_audio = self._expand(t, a.dim())
        perturbed_video = t_video * v + (1 - t_video) * v_z0
        perturbed_audio = t_audio * a + (1 - t_audio) * a_z0
        t_rf = t * (self.T - self.eps) + self.eps

        score_v, score_a = model(perturbed_video, perturbed_audio, t_rf, **model_kwargs)

        target_video = v - v_z0  # direction of the flow
        target_audio = a - a_z0  # direction of the flow
        loss_video = torch.square(score_v - target_video).mean(dim=tuple(range(2, v.dim())))
        loss_audio = torch.square(score_a - target_audio).mean(dim=tuple(range(2, a.dim())))
        # Per-frame weighting of the [B, T] losses by the time density.
        loss_video = torch.mean(loss_video * weights)
        loss_audio = torch.mean(loss_audio * weights)
        return {"loss": (loss_video + loss_audio)}

    # ------------------------------------------------------------------ #
    # Sampling
    # ------------------------------------------------------------------ #
    def sample(self, model, v_z, a_z, model_kwargs, progress=True):
        """Jointly generate video and audio with a rolling denoising window.

        Runs the warm-up phase immediately. It returns a generator function;
        calling it yields ``(video_frame, audio_frame)`` pairs indefinitely.
        ``v_z``/``a_z`` are initial noise windows ``[B, window_size, ...]``
        (assumed 5-D).
        """
        B = v_z.shape[0]
        window_indexes = self._window_indexes(B, v_z.device)

        # warm up with different number of warmup timestep to be more precise
        for i in tqdm(range(self.warmup_timesteps), disable=not progress):
            dt, _, t_rf = self.calculate_prerolling_timestep(window_indexes, i)
            score_v, score_a = model(v_z, a_z, t_rf, **model_kwargs)
            v_z = v_z.detach() + dt * score_v
            a_z = a_z.detach() + dt * score_a
        v_f, v_z = self._shift_window(v_z)
        a_f, a_z = self._shift_window(a_z)

        def yield_frame():
            nonlocal v_z, a_z
            yield (v_f, a_f)
            dt = 1 / (self.num_timesteps * self.window_size)
            while True:
                for i in range(self.num_timesteps):
                    _, t_rf = self._rolling_timestep(window_indexes, i)
                    score_v, score_a = model(v_z, a_z, t_rf, **model_kwargs)
                    v_z = v_z.detach() + dt * score_v
                    a_z = a_z.detach() + dt * score_a
                v, v_z = self._shift_window(v_z)
                a, a_z = self._shift_window(a_z)
                yield (v, a)

        return yield_frame

    def sample_a2v(self, model, v_z, a, model_kwargs, scale=1, progress=True):
        """Generate video conditioned on known audio ``a`` [B, L, ...].

        The conditioning audio is noised with one fixed noise sample. The
        gradient of the model's audio-velocity error with respect to the
        video latents, scaled by ``scale``, steers the video update. Once the
        audio runs out, it is padded with noise frames. Returns a generator
        function that yields one video frame per iteration.
        """
        B = v_z.shape[0]
        window_indexes = self._window_indexes(B, v_z.device)
        a_partial = a[:, :self.window_size]
        a_noise = torch.randn_like(a, device=v_z.device) * self.noise_scale
        a_noise_partial = a_noise[:, :self.window_size]

        with torch.enable_grad():
            # warm up with different number of warmup timestep to be more precise
            for i in tqdm(range(self.warmup_timesteps), disable=not progress):
                v_z = v_z.detach().requires_grad_(True)
                dt, t_partial, t_rf = self.calculate_prerolling_timestep(window_indexes, i)
                a_z = a_partial * t_partial + a_noise_partial * (1 - t_partial)
                score_v, score_a = model(v_z, a_z, t_rf, **model_kwargs)
                loss = torch.square((a_partial - a_noise_partial) - score_a)
                grad = torch.autograd.grad(loss.mean(), v_z)[0]
                v_z = (v_z.detach() + dt * score_v
                       - self._guidance_mask(t_partial, dt) * dt * grad * scale)
        v_f, v_z = self._shift_window(v_z)
        v_f = v_f.detach()

        def yield_frame():
            nonlocal v_z, a, a_noise
            yield v_f
            dt = 1 / (self.num_timesteps * self.window_size)
            while True:
                torch.cuda.empty_cache()
                a = a[:, 1:]
                a_noise = a_noise[:, 1:]
                if a.shape[1] < self.window_size:
                    a = torch.cat([a, torch.randn_like(a[:, 0]).unsqueeze(1) * self.noise_scale], dim=1)
                    a_noise = torch.cat([a_noise, torch.randn_like(a[:, 0]).unsqueeze(1) * self.noise_scale], dim=1)
                a_partial = a[:, :self.window_size]
                a_noise_partial = a_noise[:, :self.window_size]
                with torch.enable_grad():
                    for i in range(self.num_timesteps):
                        v_z = v_z.detach().requires_grad_(True)
                        t, t_rf = self._rolling_timestep(window_indexes, i)
                        t_partial = self._expand(t)
                        # Use the same fixed noise as the velocity target below
                        # (and as the warm-up), so input and target agree.
                        a_z = a_partial * t_partial + a_noise_partial * (1 - t_partial)
                        score_v, score_a = model(v_z, a_z, t_rf, **model_kwargs)
                        loss = torch.square((a_partial - a_noise_partial) - score_a)
                        grad = torch.autograd.grad(loss.mean(), v_z)[0]
                        v_z = (v_z.detach() + dt * score_v
                               - self._guidance_mask(t_partial, dt) * dt * grad * scale)
                v, v_z = self._shift_window(v_z)
                yield v.detach()

        return yield_frame

    def sample_v2a(self, model, v, a_z, model_kwargs, scale=2, progress=True):
        """Generate audio conditioned on known video ``v`` [B, L, ...].

        Mirror image of :meth:`sample_a2v`, with the roles of video and audio
        swapped. Returns a generator function that yields one audio frame per
        iteration.
        """
        B = a_z.shape[0]
        window_indexes = self._window_indexes(B, a_z.device)
        v_partial = v[:, :self.window_size]
        v_noise = torch.randn_like(v, device=a_z.device) * self.noise_scale
        v_noise_partial = v_noise[:, :self.window_size]

        with torch.enable_grad():
            # warm up with different number of warmup timestep to be more precise
            for i in tqdm(range(self.warmup_timesteps), disable=not progress):
                a_z = a_z.detach().requires_grad_(True)
                dt, t_partial, t_rf = self.calculate_prerolling_timestep(window_indexes, i)
                v_z = v_partial * t_partial + v_noise_partial * (1 - t_partial)
                score_v, score_a = model(v_z, a_z, t_rf, **model_kwargs)
                loss = torch.square((v_partial - v_noise_partial) - score_v)
                grad = torch.autograd.grad(loss.mean(), a_z)[0]
                a_z = (a_z.detach() + dt * score_a
                       - self._guidance_mask(t_partial, dt) * dt * grad * scale)
        a_f, a_z = self._shift_window(a_z)
        a_f = a_f.detach()

        def yield_frame():
            nonlocal v, v_noise, a_z
            yield a_f
            dt = 1 / (self.num_timesteps * self.window_size)
            while True:
                torch.cuda.empty_cache()
                v = v[:, 1:]
                v_noise = v_noise[:, 1:]
                if v.shape[1] < self.window_size:
                    v = torch.cat([v, torch.randn_like(v[:, 0]).unsqueeze(1) * self.noise_scale], dim=1)
                    # Extend the *noise* sequence (previously this concatenated
                    # onto `v`, making the noise equal to the conditioning data).
                    v_noise = torch.cat([v_noise, torch.randn_like(v[:, 0]).unsqueeze(1) * self.noise_scale], dim=1)
                v_partial = v[:, :self.window_size]
                v_noise_partial = v_noise[:, :self.window_size]
                with torch.enable_grad():
                    for i in range(self.num_timesteps):
                        a_z = a_z.detach().requires_grad_(True)
                        t, t_rf = self._rolling_timestep(window_indexes, i)
                        t_partial = self._expand(t)
                        v_z = v_partial * t_partial + v_noise_partial * (1 - t_partial)
                        score_v, score_a = model(v_z, a_z, t_rf, **model_kwargs)
                        loss = torch.square((v_partial - v_noise_partial) - score_v)
                        grad = torch.autograd.grad(loss.mean(), a_z)[0]
                        a_z = (a_z.detach() + dt * score_a
                               - self._guidance_mask(t_partial, dt) * dt * grad * scale)
                a, a_z = self._shift_window(a_z)
                yield a.detach()

        return yield_frame

    def calculate_prerolling_timestep(self, window_indexes, i):
        """Timesteps for step ``i`` of the warm-up ("pre-rolling") phase.

        During warm-up, frame ``w`` has time ``1 - (w / window_size + tw)``,
        clamped to [0, 1], where ``tw`` decreases linearly from 1 to 0 over
        ``warmup_timesteps`` steps.

        Returns:
            dt: ``[B, window_size, 1, 1, 1]``; per-frame increase in ``t`` for this step.
            t_partial: ``[B, window_size, 1, 1, 1]``; the current per-frame ``t``.
            t_rf: ``[B, window_size]``; ``t`` rescaled to model time.
        """
        tw = (self.warmup_timesteps - i) / self.warmup_timesteps
        tw_future = (self.warmup_timesteps - (i + 1)) / self.warmup_timesteps
        t = window_indexes / self.window_size + tw
        # timestep for the next iteration, to calculate dt
        t_future = window_indexes / self.window_size + tw_future
        # Swap 0 with 1, 1 is full image, 0 is full noise
        t = torch.clamp(1 - t, 0, 1)
        t_future = torch.clamp(1 - t_future, 0, 1)
        dt = self._expand(torch.abs(t_future - t))
        t_partial = self._expand(t)
        t_rf = t * (self.T - self.eps) + self.eps
        return dt, t_partial, t_rf

    def get_z0(self, batch, train=True):
        """Sample initial noise shaped like ``batch``, on ``batch``'s device.

        ``train`` is unused and kept for backward compatibility.
        """
        if self.init_type == 'gaussian':
            ### standard gaussian #+ 0.5
            return torch.randn(batch.shape, device=batch.device) * self.noise_scale
        else:
            raise NotImplementedError("INITIALIZATION TYPE NOT IMPLEMENTED")