import torch
import numpy as np
from tqdm import tqdm

from .loss import l1, l2


class RectifiedFlow:
    def __init__(self, num_steps=1000, theta=1e-5):
        self.num_step = num_steps
        self.theta = theta

    def train_loss(self, model, z0, z1, loss_type='l2'):
        zt, t, gt = self.get_train_tuple(z0, z1)
        pre = model(zt, t)
        # Dispatch through an explicit mapping instead of eval() for safety.
        loss_fn = {'l1': l1, 'l2': l2}[loss_type]
        return loss_fn(pre, gt)
    def get_train_tuple(self, z0=None, z1=None, time_step=None):
        """
        Args:
            z0 (torch.Tensor) : start distribution d0, can be Gaussian noise. default shape : (n,c,h,w)
            z1 (torch.Tensor) : target distribution d1, default shape : (n,c,h,w)
            time_step (torch.Tensor) : (batch_size,)
        Returns:
            z_t (torch.Tensor) : intermediate sample z_t, default shape : (n,c,h,w)
            t (torch.Tensor) : interpolation factor t, broadcastable to z1, default shape : (n,1,1,1)
            target (torch.Tensor) : target velocity z1 - z0, default shape : (n,c,h,w)
        """
        if time_step is None:
            # Shape (n,1,1,1) so t broadcasts over (c,h,w); created on z1's device.
            t = torch.rand((z1.shape[0], 1, 1, 1), device=z1.device)
        else:
            t = self.timestep_to_time(time_step)  # (batch_size,1,1,1)
        if z0 is None:
            z0 = torch.randn_like(z1)  # Gaussian noise
        # Linear interpolation along the straight path from z0 to z1;
        # the regression target is the constant velocity z1 - z0.
        z_t = t * z1 + (1. - t) * z0
        target = z1 - z0
        return z_t, t, target
    def get_target_with_zt_vel(self, zt, vel, time_step):
        t = self.timestep_to_time(time_step)  # (batch_size,1,1,1)
        target = zt + (1. - t) * vel
        return target
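    # Derivation (for reference): with z_t = t*z1 + (1-t)*z0 and v = z1 - z0,
    #   z_t + (1-t)*v = t*z1 + (1-t)*z0 + (1-t)*z1 - (1-t)*z0 = z1,
    # so the method above recovers z1 exactly when vel is the true velocity.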
    @torch.no_grad()
    def sample_ode_process(self, model=None, z0=None, N=None):
        ### NOTE: Use Euler method to sample from the learned flow
        if N is None:
            N = self.num_step
        dt = 1. / N
        traj = []  # to store the trajectory
        z = z0.detach().clone()
        batchsize = z.shape[0]
        traj.append(z.detach().clone())
        for i in range(N):
            t = torch.ones((batchsize, 1), device=z.device) * i / N
            pred = model(z, t)
            z = z.detach().clone() + pred * dt
            traj.append(z.detach().clone())
        return traj
    @torch.no_grad()
    def timestep_to_time(self, time_step):
        # Convert an integer time_step to a float time t in [0, 1].
        # Note the inverted convention: time_step == num_step maps to t = 0
        # (pure noise) and time_step == 0 maps to t = 1 (data).
        t = (self.num_step - time_step) / self.num_step
        if len(t.shape) == 1:
            t = t.view(-1, 1, 1, 1)
        return t
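    # Worked example (illustrative): with num_steps=1000 and time_step=250,
    # timestep_to_time gives t = (1000 - 250) / 1000 = 0.75, i.e. z_t sits
    # 75% of the way along the straight path from z0 (noise) to z1 (data).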
    @torch.no_grad()
    def sample_loop(self, model, video_start, sample_step, start_step=None, motion=None, motion_available_length=None, reinforce_condition=False):
        """predict z1 from z0
        Args:
            video_start : `torch.Tensor` (n,c,h,w)
            motion : `torch.Tensor` (n,t,c,h,w)
            start_step : `int` denoise start step
            sample_step : `int` <= self.num_step
        Returns:
            z1 (n,c,h,w) : predicted z1 from z0
        """
        if start_step is None:
            start_step = self.num_step
        step_seq = np.linspace(0, start_step, num=sample_step, endpoint=True, dtype=int)  # e.g. [0, 5, 10, ..., start_step]
        dt = 1. / sample_step
        sample = video_start
        for i in tqdm(list(reversed(step_seq))):
            # time_step
            time_step = torch.ones((sample.shape[0],), device=sample.device) * i
            # input: optionally concatenate the start frame as extra conditioning
            if reinforce_condition:
                zt = torch.cat((video_start, sample), dim=1)
            else:
                zt = sample
            # forward
            if motion is None:
                pre = model(zt, time_step)
            else:
                pre = model(motion, zt, time_step, motion_available_length)  # (n,c,h,w)
            sample = sample + pre * dt
        return sample  # (n,c,h,w)
    @torch.no_grad()
    def sample_step(self, model, zt, sample_step, start_step=None):
        """
        one-step denoise
        Args:
            model : torch.nn.Module
            start_step : `int`
        Returns:
            z : torch.Tensor, z[t-1]
        """
        if start_step is None:
            start_step = self.num_step
        dt = 1. / sample_step
        time_step = torch.ones((zt.shape[0], 1), device=zt.device) * start_step
        pre = model(zt, time_step)
        z = zt + pre * dt
        return z
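
# --- Minimal usage sketch (illustrative, not part of the original module) ---
# Shows one training step and one sampling run against a hypothetical toy
# model. Assumes a velocity model with signature model(z_t, t) -> velocity and
# that .loss provides l1/l2 as (pred, target) -> scalar callables; because of
# the relative import above, run it inside the package, e.g. `python -m pkg.module`.
if __name__ == "__main__":
    import torch.nn as nn

    class ToyVelocityModel(nn.Module):
        """Hypothetical stand-in for the real video model."""
        def __init__(self, channels=3):
            super().__init__()
            self.net = nn.Conv2d(channels, channels, kernel_size=3, padding=1)

        def forward(self, z_t, t):
            # A real model would also condition on t; ignored here for brevity.
            return self.net(z_t)

    flow = RectifiedFlow(num_steps=1000)
    model = ToyVelocityModel()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

    # One training step: regress the velocity z1 - z0 on a random z_t
    # (z0=None lets get_train_tuple draw Gaussian noise internally).
    z1 = torch.randn(4, 3, 32, 32)  # stand-in data batch
    optimizer.zero_grad()
    loss = flow.train_loss(model, z0=None, z1=z1, loss_type='l2')
    loss.backward()
    optimizer.step()

    # Sampling: integrate the learned ODE from noise with N Euler steps.
    z0 = torch.randn(4, 3, 32, 32)
    traj = flow.sample_ode_process(model=model, z0=z0, N=100)
    print(len(traj), traj[-1].shape)  # 101 states along the trajectory, (4, 3, 32, 32)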