|
|
from vitra.models.action_model.dit import DiT |
|
|
from vitra.models.action_model import create_diffusion |
|
|
from . import gaussian_diffusion as gd |
|
|
from vitra.datasets.dataset_utils import ActionFeature |
|
|
import torch |
|
|
from torch import nn |
|
|
|
|
|
def DiT_T(**kwargs):
    """Tiny DiT backbone: 3 blocks, hidden size 256, 4 attention heads."""
    config = dict(depth=3, hidden_size=256, num_heads=4)
    return DiT(**config, **kwargs)
|
|
def DiT_S(**kwargs):
    """Small DiT backbone: 6 blocks, hidden size 384, 4 attention heads."""
    config = dict(depth=6, hidden_size=384, num_heads=4)
    return DiT(**config, **kwargs)
|
|
def DiT_M(**kwargs):
    """Medium DiT backbone: 12 blocks, hidden size 384, 6 attention heads."""
    config = dict(depth=12, hidden_size=384, num_heads=6)
    return DiT(**config, **kwargs)
|
|
def DiT_B(**kwargs):
    """Base DiT backbone: 12 blocks, hidden size 768, 12 attention heads."""
    config = dict(depth=12, hidden_size=768, num_heads=12)
    return DiT(**config, **kwargs)
|
|
def DiT_L(**kwargs):
    """Large DiT backbone: 24 blocks, hidden size 1024, 16 attention heads."""
    config = dict(depth=24, hidden_size=1024, num_heads=16)
    return DiT(**config, **kwargs)
|
|
|
|
|
# Registry mapping size names to backbone factory functions
# (insertion order preserved from the original definition).
DiT_models = {
    'DiT-S': DiT_S,
    'DiT-M': DiT_M,
    'DiT-B': DiT_B,
    'DiT-T': DiT_T,
    'DiT-L': DiT_L,
}
|
|
|
|
|
class DiffusionPolicy(nn.Module):
    """Diffusion head that denoises chunks of future actions with a DiT backbone.

    Wraps a DiT denoiser (``self.net``) together with a Gaussian diffusion
    process (``self.diffusion``).  Training uses masked epsilon prediction
    (:meth:`loss`); inference draws samples with either full DDPM sampling or
    DDIM (:meth:`sample`), optionally under classifier-free guidance.
    """

    def __init__(
        self,
        token_size,
        model_type='DiT-B',
        in_channels=192,
        future_action_window_size=16,
        past_action_window_size=0,
        use_state=None,
        action_type='angle',
        diffusion_steps=100,
        state_dim=None,
        loss_type='human',
    ):
        """Build the DiT denoiser and its training diffusion process.

        Args:
            token_size: Dimension of the conditioning tokens fed to the DiT.
            model_type: Key into ``DiT_models`` selecting the backbone size.
            in_channels: Per-timestep action dimensionality.
            future_action_window_size: Number of future actions predicted;
                sampling denoises ``future_action_window_size + 1`` steps.
            past_action_window_size: Number of past actions given as context.
            use_state: When ``'DiT'``, a proprioceptive state vector (plus
                mask) is forwarded to the backbone in training and sampling.
            action_type: Action parameterization; selects the per-component
                feature ranges used by the 'human' loss layout.
            diffusion_steps: Number of diffusion timesteps for training.
            state_dim: Dimensionality of the optional state input.
            loss_type: ``'human'`` or ``'robot'``; selects the component
                layout from ``ActionFeature``.

        Raises:
            ValueError: If ``loss_type`` is neither ``'human'`` nor ``'robot'``.
        """
        super().__init__()

        self.in_channels = in_channels
        self.diffusion_steps = diffusion_steps

        # Training-time diffusion: full (un-respaced) schedule, cosine noise
        # schedule, fixed small sigma, no learned variance.
        self.diffusion = create_diffusion(
            timestep_respacing="",
            noise_schedule='squaredcos_cap_v2',
            diffusion_steps=self.diffusion_steps,
            sigma_small=True,
            learn_sigma=False,
        )

        # DDIM sampler is built lazily on first use (see `create_ddim`).
        self.ddim_diffusion = None

        # The DiT must predict a variance term iff the diffusion process
        # actually learns one.
        learn_sigma = self.diffusion.model_var_type in (
            gd.ModelVarType.LEARNED,
            gd.ModelVarType.LEARNED_RANGE,
        )

        self.past_action_window_size = past_action_window_size
        self.future_action_window_size = future_action_window_size
        self.use_state = use_state
        self.action_type = action_type

        # Per-component loss layout: name -> (start_dim, end_dim, weight).
        if loss_type == 'human':
            self.loss_components = ActionFeature.get_loss_components(action_type)
        elif loss_type == 'robot':
            self.loss_components = ActionFeature.get_xhand_loss_components()
        else:
            raise ValueError(f"Unknown loss_type: {loss_type}")

        self.net = DiT_models[model_type](
            token_size=token_size,
            action_dim=in_channels,
            class_dropout_prob=0.1,
            learn_sigma=learn_sigma,
            future_action_window_size=future_action_window_size,
            past_action_window_size=past_action_window_size,
            use_state=use_state,
            state_dim=state_dim,
        )

    def loss(self, x, z, x_mask, state=None, state_mask=None):
        """Masked epsilon-prediction loss with per-component breakdown.

        Args:
            x: Clean action chunk, shape (B, T, in_channels).
            z: Conditioning tokens for the DiT.
            x_mask: Validity mask broadcastable over ``x``; invalid entries
                contribute zero loss.
            state: Optional proprioceptive state passed to the backbone.
            state_mask: Optional mask for ``state``.

        Returns:
            Dict with the weighted total under ``"loss"`` plus one entry per
            action component (already weight-scaled).
        """
        noise = torch.randn_like(x)
        # One random diffusion timestep per batch element.
        timestep = torch.randint(
            0, self.diffusion.num_timesteps, (x.size(0),), device=x.device
        )

        x_t = self.diffusion.q_sample(x, timestep, noise)

        noise_pred = self.net(x_t, timestep, z, x_mask, state, state_mask)

        assert noise_pred.shape == noise.shape == x.shape

        # Squared error, zeroed wherever the action mask is off.
        square_delta = (noise_pred - noise) ** 2 * x_mask

        def mask_loss(from_dim, to_dim):
            # Mean squared error over a masked feature slice; 0 when the
            # slice has no valid entries in this batch.
            s = square_delta[:, :, from_dim:to_dim].sum()
            n = x_mask[:, :, from_dim:to_dim].sum()
            return s / n if n > 0 else 0

        component_losses = {}
        component_counts = {}

        for name, (start, end, weight) in self.loss_components.items():
            component_losses[name] = mask_loss(start, end) * weight
            # Number of valid (batch, time) entries for this component,
            # read from the first dim of its feature slice.
            component_counts[name] = x_mask[:, :, start].sum()

        total_count = sum(component_counts.values())

        if total_count == 0:
            # Fully-masked batch: return a zero-valued element that is still
            # connected to the graph so backward() does not break.
            loss = square_delta[0, 0, 0]
        else:
            # Weight each component's mean loss by how many valid entries
            # contributed to it.
            loss = sum(
                component_losses[k] * component_counts[k]
                for k in component_counts.keys()
            ) / total_count

        return {
            "loss": loss,
            **component_losses,
        }

    def sample(self,
               action_features,
               cfg_scale,
               current_state,
               current_state_mask,
               use_ddim,
               num_ddim_steps,
               action_masks,
               ):
        """Draw an action-chunk sample conditioned on ``action_features``.

        Args:
            action_features: Conditioning tokens, batch-first.
            cfg_scale: Classifier-free guidance scale; CFG activates when > 1.
            current_state: Proprioceptive state, forwarded only when
                ``self.use_state == 'DiT'``.
            current_state_mask: Mask for ``current_state``.
            use_ddim: Use DDIM sampling when True and ``num_ddim_steps`` is set.
            num_ddim_steps: Number of DDIM steps; the respaced sampler is
                created lazily on first use.
            action_masks: Validity mask for the action dimensions.

        Returns:
            Sampled actions of shape
            (B, future_action_window_size + 1, in_channels).
        """
        B = action_features.shape[0]
        noise = torch.randn(
            B, self.future_action_window_size + 1, self.in_channels,
            device=action_features.device,
        )

        x_mask = action_masks.to(action_features.device)

        using_cfg = cfg_scale > 1.0
        if using_cfg:
            # Double the batch: first half conditional, second half driven by
            # the learned null embedding.
            noise = torch.cat([noise, noise], 0)
            uncondition = self.net.z_embedder.uncondition
            uncondition = uncondition.unsqueeze(0)
            uncondition = uncondition.expand(B, 1, -1)
            z = torch.cat([action_features, uncondition], 0)
            # NOTE(review): x_mask and current_state are NOT doubled here even
            # though noise/z are — presumably forward_with_cfg handles the
            # broadcast; confirm against the DiT implementation.
            if self.use_state == 'DiT':
                model_kwargs = dict(
                    z=z, x_mask=x_mask,
                    cfg_scale=cfg_scale, state=current_state,
                    state_mask=current_state_mask,
                )
            else:
                model_kwargs = dict(z=z, x_mask=x_mask, cfg_scale=cfg_scale)
            sample_fn = self.net.forward_with_cfg
        else:
            if self.use_state == 'DiT':
                model_kwargs = dict(z=action_features, x_mask=x_mask,
                                    state=current_state, state_mask=current_state_mask)
            else:
                model_kwargs = dict(z=action_features, x_mask=x_mask)
            sample_fn = self.net.forward

        if use_ddim and num_ddim_steps is not None:
            if self.ddim_diffusion is None:
                self.create_ddim(ddim_step=num_ddim_steps)
            samples = self.ddim_diffusion.ddim_sample_loop(
                sample_fn,
                noise.shape,
                noise,
                clip_denoised=False,
                model_kwargs=model_kwargs,
                progress=False,
                device=action_features.device,
                eta=0.0,
            )
        else:
            # BUG FIX: this branch previously called
            # `self.ddim_diffusion.diffusion.p_sample_loop(...)`, which raises
            # AttributeError whenever DDIM is disabled (`self.ddim_diffusion`
            # is None until `create_ddim` runs, and the created sampler has no
            # `.diffusion` attribute).  Plain DDPM sampling must use the
            # training diffusion process directly.
            samples = self.diffusion.p_sample_loop(
                sample_fn,
                noise.shape,
                noise,
                clip_denoised=False,
                model_kwargs=model_kwargs,
                progress=False,
                device=action_features.device,
            )

        if using_cfg:
            # Keep only the conditional half of the doubled batch.
            samples, _ = samples.chunk(2, dim=0)
        return samples

    def create_ddim(self, ddim_step=10):
        """Create and cache a DDIM sampler respaced to ``ddim_step`` steps.

        Uses the same cosine noise schedule and total step count as the
        training diffusion so the respacing is consistent.

        Returns:
            The cached DDIM diffusion object (also stored on
            ``self.ddim_diffusion``).
        """
        self.ddim_diffusion = create_diffusion(
            timestep_respacing="ddim" + str(ddim_step),
            noise_schedule='squaredcos_cap_v2',
            diffusion_steps=self.diffusion_steps,
            sigma_small=True,
            learn_sigma=False,
        )
        return self.ddim_diffusion