# semo/Semo/model/rectified_flow.py
import numpy as np
import torch
from tqdm import tqdm

from .loss import l1, l2


class RectifiedFlow:
    """Rectified flow training and Euler ODE sampling utilities."""

    def __init__(self, num_steps=1000, theta=1e-5):
        self.num_step = num_steps
        self.theta = theta

    def train_loss(self, model, z0, z1, loss_type='l2'):
        # Sample an interpolation point and regress the model output onto the
        # constant velocity target z1 - z0.
        zt, t, gt = self.get_train_tuple(z0, z1)
        pre = model(zt, t)
        loss_fn = {'l1': l1, 'l2': l2}[loss_type]  # explicit lookup instead of eval()
        return loss_fn(pre, gt)

    def get_train_tuple(self, z0=None, z1=None, time_step=None):
        """
        Args:
            z0 (torch.Tensor): source distribution sample d0, e.g. Gaussian noise. Default shape: (n, c, h, w)
            z1 (torch.Tensor): target distribution sample d1. Default shape: (n, c, h, w)
            time_step (torch.Tensor): integer time steps, shape (batch_size,)
        Returns:
            z_t (torch.Tensor): intermediate sample z_t, default shape: (n, c, h, w)
            t (torch.Tensor): interpolation factor t in [0, 1], shape (n, 1) when sampled randomly
            target (torch.Tensor): regression target, the constant velocity z1 - z0, default shape: (n, c, h, w)
        """
        if time_step is None:
            t = torch.rand((z1.shape[0], 1), device=z1.device)
        else:
            t = self.timestep_to_time(time_step)  # broadcastable over (n, c, h, w)
        if z0 is None:
            z0 = torch.randn_like(z1)  # Gaussian noise
        # Reshape t so it broadcasts over the (n, c, h, w) inputs.
        t_b = t.reshape(t.shape[0], *([1] * (z1.dim() - 1)))
        z_t = t_b * z1 + (1. - t_b) * z0
        target = z1 - z0
        return z_t, t, target
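
    # Hedged usage sketch (the names `flow` and `z1` below are illustrative,
    # not part of the original API):
    #   flow = RectifiedFlow(num_steps=1000)
    #   z1 = torch.randn(4, 3, 32, 32)               # a batch of data samples
    #   z_t, t, target = flow.get_train_tuple(z1=z1)
    # At t = 0, z_t equals the noise z0; at t = 1, z_t equals the data z1, and the
    # regression target is the constant velocity z1 - z0 along the straight path.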

    def get_target_with_zt_vel(self, zt, vel, time_step):
        # Reconstruct z1 from an intermediate sample and a velocity:
        # z1 = z_t + (1 - t) * vel, since z_t = t * z1 + (1 - t) * z0 and vel = z1 - z0.
        t = self.timestep_to_time(time_step)  # broadcastable over zt
        target = zt + (1. - t) * vel
        return target

    @torch.no_grad()
    def sample_ode_process(self, model=None, z0=None, N=None):
        # NOTE: use the Euler method to sample from the learned flow.
        if N is None:
            N = self.num_step
        dt = 1. / N
        traj = []  # stores the full trajectory
        z = z0.detach().clone()
        batchsize = z.shape[0]
        traj.append(z.detach().clone())
        for i in range(N):
            t = torch.ones((batchsize, 1), device=z.device) * i / N
            pred = model(z, t)  # predicted velocity at (z, t)
            z = z.detach().clone() + pred * dt  # Euler step
            traj.append(z.detach().clone())
        return traj
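
    # Hedged usage sketch (assumes `model(z, t)` returns a velocity with the same
    # shape as z; `velocity_net` is an illustrative name):
    #   z0 = torch.randn(4, 3, 32, 32)
    #   traj = flow.sample_ode_process(model=velocity_net, z0=z0, N=100)
    #   z1_hat = traj[-1]   # the final Euler iterate approximates a data sample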

    @torch.no_grad()
    def timestep_to_time(self, time_step):
        # Convert an integer time step tensor in [0, num_step] to a float time t in [0, 1].
        # time_step counts down: time_step == num_step maps to t == 0, time_step == 0 to t == 1.
        t = (self.num_step - time_step) / self.num_step
        if len(t.shape) == 1:
            t = t.view(-1, 1, 1, 1)  # make t broadcastable over (n, c, h, w)
        return t

    @torch.no_grad()
    def sample_loop(self, model, video_start, sample_step, start_step=None, motion=None, motion_available_length=None, reinforce_condition=False):
        """Predict z1 from z0.
        Args:
            video_start : `torch.Tensor` (n, c, h, w), the starting sample z0
            motion : `torch.Tensor` (n, t, c, h, w), optional motion condition
            motion_available_length : forwarded to the model together with `motion`
            start_step : `int`, denoising start step
            sample_step : `int`, number of Euler steps, <= self.num_step
            reinforce_condition : `bool`, if True, concatenate `video_start` with the
                current sample along the channel dimension as the model input
        Returns:
            z1 (n, c, h, w) : the predicted z1
        """
        if start_step is None:
            start_step = self.num_step
        # Evenly spaced time steps from 0 to start_step, e.g. [0, 5, 10, ..., start_step].
        step_seq = np.linspace(0, start_step, num=sample_step, endpoint=True, dtype=int)
        dt = 1. / sample_step
        sample = video_start
        for i in tqdm(list(reversed(step_seq))):
            # time_step
            time_step = torch.ones((sample.shape[0],)).to(sample.device)
            time_step = time_step * i
            # model input
            if reinforce_condition:
                zt = torch.cat((video_start, sample), dim=1)
            else:
                zt = sample
            # forward pass: the model predicts a velocity
            if motion is None:
                pre = model(zt, time_step)
            else:
                pre = model(motion, zt, time_step, motion_available_length)  # (n, c, h, w)
            sample = sample + pre * dt  # Euler step
        return sample  # (n, c, h, w)

    @torch.no_grad()
    def sample_step(self, model, zt, sample_step, start_step=None):
        """
        One-step denoise.
        Args:
            model : torch.nn.Module
            zt : `torch.Tensor`, the current sample
            sample_step : `int`, total number of sampling steps (sets dt = 1 / sample_step)
            start_step : `int`, the current time step
        Returns:
            z : torch.Tensor, z[t-1]
        """
        if start_step is None:
            start_step = self.num_step
        dt = 1. / sample_step
        time_step = torch.ones((zt.shape[0], 1), device=zt.device) * start_step
        pre = model(zt, time_step)
        z = zt + pre * dt  # single Euler step
        return z
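

# Minimal end-to-end sketch (not part of the original module): it assumes a toy
# velocity model with the `model(z, t)` signature used by `train_loss` and
# `sample_ode_process`; the network below is a placeholder, not the Semo model.
# Because of the relative import of `.loss`, run this as a module inside the
# package (e.g. `python -m ...rectified_flow`) rather than as a standalone script.
if __name__ == "__main__":
    import torch.nn as nn

    class _ToyVelocityNet(nn.Module):
        # Ignores t and predicts a velocity field with the same shape as z.
        def __init__(self, channels=3):
            super().__init__()
            self.net = nn.Conv2d(channels, channels, kernel_size=3, padding=1)

        def forward(self, z, t):
            return self.net(z)

    flow = RectifiedFlow(num_steps=1000)
    model = _ToyVelocityNet()

    # Training loss on a random batch (z0 defaults to Gaussian noise).
    z1 = torch.randn(4, 3, 32, 32)
    loss = flow.train_loss(model, z0=None, z1=z1, loss_type='l2')
    print("train loss:", loss)

    # Euler sampling from Gaussian noise.
    z0 = torch.randn(4, 3, 32, 32)
    traj = flow.sample_ode_process(model=model, z0=z0, N=10)
    print("trajectory length:", len(traj), "final shape:", tuple(traj[-1].shape))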