| import torch |
| import torch.nn as nn |
| import numpy as np |
| |
| import torch.nn.functional as F |
| import clip |
| from einops import rearrange, repeat |
| import math |
| from random import random |
| from tqdm.auto import tqdm |
| from typing import Callable, Optional, List, Dict |
| from copy import deepcopy |
| from functools import partial |
| from models.mask_transformer.tools import * |
| from torch.distributions.categorical import Categorical |
|
|
class InputProcess(nn.Module):
    """Project token features into the transformer latent space.

    Input arrives batch-first (b, seqlen, input_feats) and leaves
    seq-first (seqlen, b, latent_dim), as nn.TransformerEncoder expects.
    """
    def __init__(self, input_feats, latent_dim):
        super().__init__()
        self.input_feats = input_feats
        self.latent_dim = latent_dim
        self.poseEmbedding = nn.Linear(self.input_feats, self.latent_dim)

    def forward(self, x):
        # (b, seqlen, input_feats) -> (seqlen, b, input_feats)
        seq_first = x.permute((1, 0, 2))
        # Linear projection into the latent dimension.
        return self.poseEmbedding(seq_first)
|
|
class PositionalEncoding(nn.Module):
    """Sinusoidal positional encoding (Vaswani et al. style).

    A (max_len, 1, d_model) table is precomputed and stored as a
    non-trainable buffer; forward adds the first `seqlen` rows to a
    seq-first input and applies dropout.
    """

    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        positions = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        # Geometric frequency ladder over the even dimensions.
        freqs = torch.exp(torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model))
        table = torch.zeros(max_len, d_model)
        table[:, 0::2] = torch.sin(positions * freqs)
        table[:, 1::2] = torch.cos(positions * freqs)
        # (max_len, d_model) -> (max_len, 1, d_model) so it broadcasts over batch.
        self.register_buffer('pe', table.unsqueeze(0).transpose(0, 1))

    def forward(self, x):
        # x: (seqlen, b, d_model); one encoding row per position.
        x = x + self.pe[:x.shape[0], :]
        return self.dropout(x)
|
|
class OutputProcess_Bert(nn.Module):
    """BERT-style prediction head: dense -> GELU -> LayerNorm -> token logits.

    Takes seq-first hidden states (seqlen, b, latent_dim) and returns
    logits shaped (b, out_feats, seqlen), ready for cross-entropy.
    """
    def __init__(self, out_feats, latent_dim):
        super().__init__()
        self.dense = nn.Linear(latent_dim, latent_dim)
        self.transform_act_fn = F.gelu
        self.LayerNorm = nn.LayerNorm(latent_dim, eps=1e-12)
        self.poseFinal = nn.Linear(latent_dim, out_feats)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        transformed = self.LayerNorm(self.transform_act_fn(self.dense(hidden_states)))
        logits = self.poseFinal(transformed)
        # (seqlen, b, out_feats) -> (b, out_feats, seqlen)
        return logits.permute(1, 2, 0)
|
|
class OutputProcess(nn.Module):
    """Prediction head used by the residual transformer.

    Same dense -> GELU -> LayerNorm -> linear pipeline as the BERT head;
    maps (seqlen, b, latent_dim) to (b, out_feats, seqlen).
    """
    def __init__(self, out_feats, latent_dim):
        super().__init__()
        self.dense = nn.Linear(latent_dim, latent_dim)
        self.transform_act_fn = F.gelu
        self.LayerNorm = nn.LayerNorm(latent_dim, eps=1e-12)
        self.poseFinal = nn.Linear(latent_dim, out_feats)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        h = self.dense(hidden_states)
        h = self.transform_act_fn(h)
        h = self.LayerNorm(h)
        h = self.poseFinal(h)
        # (seqlen, b, out_feats) -> (b, out_feats, seqlen)
        return h.permute(1, 2, 0)
|
|
|
|
class MaskTransformer(nn.Module):
    """Text/action-conditioned masked transformer over base-layer VQ motion tokens.

    Trained MaskGIT-style (random masking, cross-entropy on masked slots) and
    decoded by iterative re-masking in `generate`/`edit`.
    """

    def __init__(self, code_dim, cond_mode, latent_dim=256, ff_size=1024, num_layers=8,
                 num_heads=4, dropout=0.1, clip_dim=512, cond_drop_prob=0.1,
                 clip_version=None, opt=None, **kargs):
        super(MaskTransformer, self).__init__()
        print(f'latent_dim: {latent_dim}, ff_size: {ff_size}, nlayers: {num_layers}, nheads: {num_heads}, dropout: {dropout}')

        self.code_dim = code_dim
        self.latent_dim = latent_dim
        self.clip_dim = clip_dim
        self.dropout = dropout
        self.opt = opt

        self.cond_mode = cond_mode
        # Probability of dropping the condition during training
        # (enables classifier-free guidance at sampling time).
        self.cond_drop_prob = cond_drop_prob

        if self.cond_mode == 'action':
            assert 'num_actions' in kargs
        self.num_actions = kargs.get('num_actions', 1)

        '''
        Preparing Networks
        '''
        self.input_process = InputProcess(self.code_dim, self.latent_dim)
        self.position_enc = PositionalEncoding(self.latent_dim, self.dropout)

        seqTransEncoderLayer = nn.TransformerEncoderLayer(d_model=self.latent_dim,
                                                          nhead=num_heads,
                                                          dim_feedforward=ff_size,
                                                          dropout=dropout,
                                                          activation='gelu')

        self.seqTransEncoder = nn.TransformerEncoder(seqTransEncoderLayer,
                                                     num_layers=num_layers)

        # One-hot encoder for action labels.
        self.encode_action = partial(F.one_hot, num_classes=self.num_actions)

        # Projects the raw condition (CLIP text feature / action one-hot)
        # into the transformer latent space.
        if self.cond_mode == 'text':
            self.cond_emb = nn.Linear(self.clip_dim, self.latent_dim)
        elif self.cond_mode == 'action':
            self.cond_emb = nn.Linear(self.num_actions, self.latent_dim)
        elif self.cond_mode == 'uncond':
            self.cond_emb = nn.Identity()
        else:
            raise KeyError("Unsupported condition mode!!!")

        # Two extra ids beyond the codebook: [MASK] and [PAD].
        _num_tokens = opt.num_tokens + 2
        self.mask_id = opt.num_tokens
        self.pad_id = opt.num_tokens + 1

        # Head predicts only real codebook tokens (no [MASK]/[PAD] logits).
        self.output_process = OutputProcess_Bert(out_feats=opt.num_tokens, latent_dim=latent_dim)

        self.token_emb = nn.Embedding(_num_tokens, self.code_dim)

        self.apply(self.__init_weights)

        '''
        Preparing frozen weights
        '''

        if self.cond_mode == 'text':
            print('Loading CLIP...')
            self.clip_version = clip_version
            self.clip_model = self.load_and_freeze_clip(clip_version)

        # cosine_schedule comes from models.mask_transformer.tools.
        self.noise_schedule = cosine_schedule
|
|
| def load_and_freeze_token_emb(self, codebook): |
| ''' |
| :param codebook: (c, d) |
| :return: |
| ''' |
| assert self.training, 'Only necessary in training mode' |
| c, d = codebook.shape |
| self.token_emb.weight = nn.Parameter(torch.cat([codebook, torch.zeros(size=(2, d), device=codebook.device)], dim=0)) |
| self.token_emb.requires_grad_(False) |
| |
| |
| print("Token embedding initialized!") |
|
|
| def __init_weights(self, module): |
| if isinstance(module, (nn.Linear, nn.Embedding)): |
| module.weight.data.normal_(mean=0.0, std=0.02) |
| if isinstance(module, nn.Linear) and module.bias is not None: |
| module.bias.data.zero_() |
| elif isinstance(module, nn.LayerNorm): |
| module.bias.data.zero_() |
| module.weight.data.fill_(1.0) |
|
|
| def parameters_wo_clip(self): |
| return [p for name, p in self.named_parameters() if not name.startswith('clip_model.')] |
|
|
| def load_and_freeze_clip(self, clip_version): |
| clip_model, clip_preprocess = clip.load(clip_version, device='cpu', |
| jit=False) |
| |
| clip.model.convert_weights( |
| clip_model) |
| |
|
|
| |
| clip_model.eval() |
| for p in clip_model.parameters(): |
| p.requires_grad = False |
|
|
| return clip_model |
|
|
| def encode_text(self, raw_text): |
| device = next(self.parameters()).device |
| text = clip.tokenize(raw_text, truncate=True).to(device) |
| feat_clip_text = self.clip_model.encode_text(text).float() |
| return feat_clip_text |
|
|
| def mask_cond(self, cond, force_mask=False): |
| bs, d = cond.shape |
| if force_mask: |
| return torch.zeros_like(cond) |
| elif self.training and self.cond_drop_prob > 0.: |
| mask = torch.bernoulli(torch.ones(bs, device=cond.device) * self.cond_drop_prob).view(bs, 1) |
| return cond * (1. - mask) |
| else: |
| return cond |
|
|
    def trans_forward(self, motion_ids, cond, padding_mask, force_mask=False):
        '''
        :param motion_ids: (b, seqlen)
        :padding_mask: (b, seqlen), all pad positions are TRUE else FALSE
        :param cond: (b, embed_dim) for text, (b, num_actions) for action
        :param force_mask: boolean
        :return:
            -logits: (b, num_token, seqlen)
        '''
        # Classifier-free guidance: possibly zero the condition.
        cond = self.mask_cond(cond, force_mask=force_mask)

        # (b, seqlen) -> (b, seqlen, code_dim)
        x = self.token_emb(motion_ids)

        # (b, seqlen, code_dim) -> (seqlen, b, latent_dim)
        x = self.input_process(x)

        # Condition becomes one extra leading "token": (1, b, latent_dim).
        cond = self.cond_emb(cond).unsqueeze(0)

        x = self.position_enc(x)
        xseq = torch.cat([cond, x], dim=0)

        # Prepend a never-padded slot so the mask stays aligned with xseq.
        padding_mask = torch.cat([torch.zeros_like(padding_mask[:, 0:1]), padding_mask], dim=1)

        # Drop the condition position from the output before projecting to logits.
        output = self.seqTransEncoder(xseq, src_key_padding_mask=padding_mask)[1:]
        logits = self.output_process(output)
        return logits
|
|
| def forward(self, ids, y, m_lens): |
| ''' |
| :param ids: (b, n) |
| :param y: raw text for cond_mode=text, (b, ) for cond_mode=action |
| :m_lens: (b,) |
| :return: |
| ''' |
|
|
| bs, ntokens = ids.shape |
| device = ids.device |
|
|
| |
| non_pad_mask = lengths_to_mask(m_lens, ntokens) |
| ids = torch.where(non_pad_mask, ids, self.pad_id) |
|
|
| force_mask = False |
| if self.cond_mode == 'text': |
| with torch.no_grad(): |
| cond_vector = self.encode_text(y) |
| elif self.cond_mode == 'action': |
| cond_vector = self.enc_action(y).to(device).float() |
| elif self.cond_mode == 'uncond': |
| cond_vector = torch.zeros(bs, self.latent_dim).float().to(device) |
| force_mask = True |
| else: |
| raise NotImplementedError("Unsupported condition mode!!!") |
|
|
|
|
| ''' |
| Prepare mask |
| ''' |
| rand_time = uniform((bs,), device=device) |
| rand_mask_probs = self.noise_schedule(rand_time) |
| num_token_masked = (ntokens * rand_mask_probs).round().clamp(min=1) |
|
|
| batch_randperm = torch.rand((bs, ntokens), device=device).argsort(dim=-1) |
| |
| mask = batch_randperm < num_token_masked.unsqueeze(-1) |
|
|
| |
| mask &= non_pad_mask |
|
|
| |
| labels = torch.where(mask, ids, self.mask_id) |
|
|
| x_ids = ids.clone() |
|
|
| |
| |
| mask_rid = get_mask_subset_prob(mask, 0.1) |
| rand_id = torch.randint_like(x_ids, high=self.opt.num_tokens) |
| x_ids = torch.where(mask_rid, rand_id, x_ids) |
| |
| mask_mid = get_mask_subset_prob(mask & ~mask_rid, 0.88) |
|
|
| |
|
|
| x_ids = torch.where(mask_mid, self.mask_id, x_ids) |
|
|
| logits = self.trans_forward(x_ids, cond_vector, ~non_pad_mask, force_mask) |
| ce_loss, pred_id, acc = cal_performance(logits, labels, ignore_index=self.mask_id) |
|
|
| return ce_loss, pred_id, acc |
|
|
| def forward_with_cond_scale(self, |
| motion_ids, |
| cond_vector, |
| padding_mask, |
| cond_scale=3, |
| force_mask=False): |
| |
| |
| if force_mask: |
| return self.trans_forward(motion_ids, cond_vector, padding_mask, force_mask=True) |
|
|
| logits = self.trans_forward(motion_ids, cond_vector, padding_mask) |
| if cond_scale == 1: |
| return logits |
|
|
| aux_logits = self.trans_forward(motion_ids, cond_vector, padding_mask, force_mask=True) |
|
|
| scaled_logits = aux_logits + (logits - aux_logits) * cond_scale |
| return scaled_logits |
|
|
| @torch.no_grad() |
| @eval_decorator |
| def generate(self, |
| conds, |
| m_lens, |
| timesteps: int, |
| cond_scale: int, |
| temperature=1, |
| topk_filter_thres=0.9, |
| gsample=False, |
| force_mask=False |
| ): |
| |
| |
|
|
| device = next(self.parameters()).device |
| seq_len = max(m_lens) |
| batch_size = len(m_lens) |
|
|
| if self.cond_mode == 'text': |
| with torch.no_grad(): |
| cond_vector = self.encode_text(conds) |
| elif self.cond_mode == 'action': |
| cond_vector = self.enc_action(conds).to(device) |
| elif self.cond_mode == 'uncond': |
| cond_vector = torch.zeros(batch_size, self.latent_dim).float().to(device) |
| else: |
| raise NotImplementedError("Unsupported condition mode!!!") |
|
|
| padding_mask = ~lengths_to_mask(m_lens, seq_len) |
| |
|
|
| |
| ids = torch.where(padding_mask, self.pad_id, self.mask_id) |
| scores = torch.where(padding_mask, 1e5, 0.) |
| starting_temperature = temperature |
|
|
| for timestep, steps_until_x0 in zip(torch.linspace(0, 1, timesteps, device=device), reversed(range(timesteps))): |
| |
| rand_mask_prob = self.noise_schedule(timestep) |
|
|
| ''' |
| Maskout, and cope with variable length |
| ''' |
| |
| num_token_masked = torch.round(rand_mask_prob * m_lens).clamp(min=1) |
|
|
| |
| sorted_indices = scores.argsort( |
| dim=1) |
| ranks = sorted_indices.argsort(dim=1) |
| is_mask = (ranks < num_token_masked.unsqueeze(-1)) |
| ids = torch.where(is_mask, self.mask_id, ids) |
|
|
| ''' |
| Preparing input |
| ''' |
| |
| logits = self.forward_with_cond_scale(ids, cond_vector=cond_vector, |
| padding_mask=padding_mask, |
| cond_scale=cond_scale, |
| force_mask=force_mask) |
|
|
| logits = logits.permute(0, 2, 1) |
| |
| |
| filtered_logits = top_k(logits, topk_filter_thres, dim=-1) |
|
|
| ''' |
| Update ids |
| ''' |
| |
| temperature = starting_temperature |
| |
| |
| |
| |
| |
| if gsample: |
| |
| pred_ids = gumbel_sample(filtered_logits, temperature=temperature, dim=-1) |
| else: |
| |
| probs = F.softmax(filtered_logits, dim=-1) |
| |
| |
| pred_ids = Categorical(probs / temperature).sample() |
|
|
| |
| |
| ids = torch.where(is_mask, pred_ids, ids) |
|
|
| ''' |
| Updating scores |
| ''' |
| probs_without_temperature = logits.softmax(dim=-1) |
| scores = probs_without_temperature.gather(2, pred_ids.unsqueeze(dim=-1)) |
| scores = scores.squeeze(-1) |
|
|
| |
| scores = scores.masked_fill(~is_mask, 1e5) |
|
|
| ids = torch.where(padding_mask, -1, ids) |
| |
| return ids |
|
|
|
|
| @torch.no_grad() |
| @eval_decorator |
| def edit(self, |
| conds, |
| tokens, |
| m_lens, |
| timesteps: int, |
| cond_scale: int, |
| temperature=1, |
| topk_filter_thres=0.9, |
| gsample=False, |
| force_mask=False, |
| edit_mask=None, |
| padding_mask=None, |
| ): |
|
|
| assert edit_mask.shape == tokens.shape if edit_mask is not None else True |
| device = next(self.parameters()).device |
| seq_len = tokens.shape[1] |
|
|
| if self.cond_mode == 'text': |
| with torch.no_grad(): |
| cond_vector = self.encode_text(conds) |
| elif self.cond_mode == 'action': |
| cond_vector = self.enc_action(conds).to(device) |
| elif self.cond_mode == 'uncond': |
| cond_vector = torch.zeros(1, self.latent_dim).float().to(device) |
| else: |
| raise NotImplementedError("Unsupported condition mode!!!") |
|
|
| if padding_mask == None: |
| padding_mask = ~lengths_to_mask(m_lens, seq_len) |
|
|
| |
| if edit_mask == None: |
| mask_free = True |
| ids = torch.where(padding_mask, self.pad_id, tokens) |
| edit_mask = torch.ones_like(padding_mask) |
| edit_mask = edit_mask & ~padding_mask |
| edit_len = edit_mask.sum(dim=-1) |
| scores = torch.where(edit_mask, 0., 1e5) |
| else: |
| mask_free = False |
| edit_mask = edit_mask & ~padding_mask |
| edit_len = edit_mask.sum(dim=-1) |
| ids = torch.where(edit_mask, self.mask_id, tokens) |
| scores = torch.where(edit_mask, 0., 1e5) |
| starting_temperature = temperature |
|
|
| for timestep, steps_until_x0 in zip(torch.linspace(0, 1, timesteps, device=device), reversed(range(timesteps))): |
| |
| rand_mask_prob = 0.16 if mask_free else self.noise_schedule(timestep) |
|
|
| ''' |
| Maskout, and cope with variable length |
| ''' |
| |
| num_token_masked = torch.round(rand_mask_prob * edit_len).clamp(min=1) |
|
|
| |
| sorted_indices = scores.argsort( |
| dim=1) |
| ranks = sorted_indices.argsort(dim=1) |
| is_mask = (ranks < num_token_masked.unsqueeze(-1)) |
| |
| ids = torch.where(is_mask, self.mask_id, ids) |
|
|
| ''' |
| Preparing input |
| ''' |
| |
| logits = self.forward_with_cond_scale(ids, cond_vector=cond_vector, |
| padding_mask=padding_mask, |
| cond_scale=cond_scale, |
| force_mask=force_mask) |
|
|
| logits = logits.permute(0, 2, 1) |
| |
| |
| filtered_logits = top_k(logits, topk_filter_thres, dim=-1) |
|
|
| ''' |
| Update ids |
| ''' |
| |
| temperature = starting_temperature |
| |
| |
| |
| |
| |
| if gsample: |
| |
| pred_ids = gumbel_sample(filtered_logits, temperature=temperature, dim=-1) |
| else: |
| |
| probs = F.softmax(filtered_logits, dim=-1) |
| |
| |
| pred_ids = Categorical(probs / temperature).sample() |
|
|
| |
| |
| ids = torch.where(is_mask, pred_ids, ids) |
|
|
| ''' |
| Updating scores |
| ''' |
| probs_without_temperature = logits.softmax(dim=-1) |
| scores = probs_without_temperature.gather(2, pred_ids.unsqueeze(dim=-1)) |
| scores = scores.squeeze(-1) |
|
|
| |
| scores = scores.masked_fill(~edit_mask, 1e5) if mask_free else scores.masked_fill(~is_mask, 1e5) |
|
|
| ids = torch.where(padding_mask, -1, ids) |
| |
| return ids |
|
|
| @torch.no_grad() |
| @eval_decorator |
| def edit_beta(self, |
| conds, |
| conds_og, |
| tokens, |
| m_lens, |
| cond_scale: int, |
| force_mask=False, |
| ): |
|
|
| device = next(self.parameters()).device |
| seq_len = tokens.shape[1] |
|
|
| if self.cond_mode == 'text': |
| with torch.no_grad(): |
| cond_vector = self.encode_text(conds) |
| if conds_og is not None: |
| cond_vector_og = self.encode_text(conds_og) |
| else: |
| cond_vector_og = None |
| elif self.cond_mode == 'action': |
| cond_vector = self.enc_action(conds).to(device) |
| if conds_og is not None: |
| cond_vector_og = self.enc_action(conds_og).to(device) |
| else: |
| cond_vector_og = None |
| else: |
| raise NotImplementedError("Unsupported condition mode!!!") |
|
|
| padding_mask = ~lengths_to_mask(m_lens, seq_len) |
|
|
| |
| ids = torch.where(padding_mask, self.pad_id, tokens) |
|
|
| ''' |
| Preparing input |
| ''' |
| |
| logits = self.forward_with_cond_scale(ids, |
| cond_vector=cond_vector, |
| cond_vector_neg=cond_vector_og, |
| padding_mask=padding_mask, |
| cond_scale=cond_scale, |
| force_mask=force_mask) |
|
|
| logits = logits.permute(0, 2, 1) |
|
|
| ''' |
| Updating scores |
| ''' |
| probs_without_temperature = logits.softmax(dim=-1) |
| tokens[tokens == -1] = 0 |
| og_tokens_scores = probs_without_temperature.gather(2, tokens.unsqueeze(dim=-1)) |
| og_tokens_scores = og_tokens_scores.squeeze(-1) |
|
|
| return og_tokens_scores |
|
|
|
|
class ResidualTransformer(nn.Module):
    """Transformer predicting residual VQ layers (q >= 1).

    Input is the summed embeddings of all lower layers plus the target
    quantizer index and a text/action condition.
    """

    def __init__(self, code_dim, cond_mode, latent_dim=256, ff_size=1024, num_layers=8, cond_drop_prob=0.1,
                 num_heads=4, dropout=0.1, clip_dim=512, shared_codebook=False, share_weight=False,
                 clip_version=None, opt=None, **kargs):
        super(ResidualTransformer, self).__init__()
        print(f'latent_dim: {latent_dim}, ff_size: {ff_size}, nlayers: {num_layers}, nheads: {num_heads}, dropout: {dropout}')

        self.code_dim = code_dim
        self.latent_dim = latent_dim
        self.clip_dim = clip_dim
        self.dropout = dropout
        self.opt = opt

        self.cond_mode = cond_mode

        if self.cond_mode == 'action':
            assert 'num_actions' in kargs
        self.num_actions = kargs.get('num_actions', 1)
        # Probability of dropping the condition (classifier-free guidance).
        self.cond_drop_prob = cond_drop_prob

        '''
        Preparing Networks
        '''
        self.input_process = InputProcess(self.code_dim, self.latent_dim)
        self.position_enc = PositionalEncoding(self.latent_dim, self.dropout)

        seqTransEncoderLayer = nn.TransformerEncoderLayer(d_model=self.latent_dim,
                                                          nhead=num_heads,
                                                          dim_feedforward=ff_size,
                                                          dropout=dropout,
                                                          activation='gelu')

        self.seqTransEncoder = nn.TransformerEncoder(seqTransEncoderLayer,
                                                     num_layers=num_layers)

        # One-hot encoders for the quantizer index and for action labels.
        self.encode_quant = partial(F.one_hot, num_classes=self.opt.num_quantizers)
        self.encode_action = partial(F.one_hot, num_classes=self.num_actions)

        self.quant_emb = nn.Linear(self.opt.num_quantizers, self.latent_dim)
        # Projects the raw condition into the transformer latent space.
        if self.cond_mode == 'text':
            self.cond_emb = nn.Linear(self.clip_dim, self.latent_dim)
        elif self.cond_mode == 'action':
            self.cond_emb = nn.Linear(self.num_actions, self.latent_dim)
        else:
            raise KeyError("Unsupported condition mode!!!")

        # One extra id beyond the codebook: [PAD].
        _num_tokens = opt.num_tokens + 1
        self.pad_id = opt.num_tokens

        # Head emits code_dim features; token logits come from output_project().
        self.output_process = OutputProcess(out_feats=code_dim, latent_dim=latent_dim)

        if shared_codebook:
            # One codebook shared by all residual layers.
            token_embed = nn.Parameter(torch.normal(mean=0, std=0.02, size=(_num_tokens, code_dim)))
            self.token_embed_weight = token_embed.expand(opt.num_quantizers-1, _num_tokens, code_dim)
            if share_weight:
                # Output projection reuses the embedding weights (tied).
                self.output_proj_weight = self.token_embed_weight
                self.output_proj_bias = None
            else:
                output_proj = nn.Parameter(torch.normal(mean=0, std=0.02, size=(_num_tokens, code_dim)))
                output_bias = nn.Parameter(torch.zeros(size=(_num_tokens,)))

                self.output_proj_weight = output_proj.expand(opt.num_quantizers-1, _num_tokens, code_dim)
                self.output_proj_bias = output_bias.expand(opt.num_quantizers-1, _num_tokens)

        else:
            if share_weight:
                # Embedding and projection share the middle layers; each keeps
                # one private layer. process_embed_proj_weight() assembles the
                # full per-layer stacks before use.
                self.embed_proj_shared_weight = nn.Parameter(torch.normal(mean=0, std=0.02, size=(opt.num_quantizers - 2, _num_tokens, code_dim)))
                self.token_embed_weight_ = nn.Parameter(torch.normal(mean=0, std=0.02, size=(1, _num_tokens, code_dim)))
                self.output_proj_weight_ = nn.Parameter(torch.normal(mean=0, std=0.02, size=(1, _num_tokens, code_dim)))
                self.output_proj_bias = None
                self.registered = False
            else:
                output_proj_weight = torch.normal(mean=0, std=0.02,
                                                  size=(opt.num_quantizers - 1, _num_tokens, code_dim))

                self.output_proj_weight = nn.Parameter(output_proj_weight)
                # NOTE(review): bias is sized (num_quantizers, _num_tokens) while
                # the weight stack has num_quantizers - 1 layers — looks off by
                # one; confirm intended indexing in output_project().
                self.output_proj_bias = nn.Parameter(torch.zeros(size=(opt.num_quantizers, _num_tokens)))
                token_embed_weight = torch.normal(mean=0, std=0.02,
                                                  size=(opt.num_quantizers - 1, _num_tokens, code_dim))
                self.token_embed_weight = nn.Parameter(token_embed_weight)

        self.apply(self.__init_weights)
        self.shared_codebook = shared_codebook
        self.share_weight = share_weight

        if self.cond_mode == 'text':
            print('Loading CLIP...')
            self.clip_version = clip_version
            self.clip_model = self.load_and_freeze_clip(clip_version)
|
|
| |
|
|
| def mask_cond(self, cond, force_mask=False): |
| bs, d = cond.shape |
| if force_mask: |
| return torch.zeros_like(cond) |
| elif self.training and self.cond_drop_prob > 0.: |
| mask = torch.bernoulli(torch.ones(bs, device=cond.device) * self.cond_drop_prob).view(bs, 1) |
| return cond * (1. - mask) |
| else: |
| return cond |
|
|
| def __init_weights(self, module): |
| if isinstance(module, (nn.Linear, nn.Embedding)): |
| module.weight.data.normal_(mean=0.0, std=0.02) |
| if isinstance(module, nn.Linear) and module.bias is not None: |
| module.bias.data.zero_() |
| elif isinstance(module, nn.LayerNorm): |
| module.bias.data.zero_() |
| module.weight.data.fill_(1.0) |
|
|
| def parameters_wo_clip(self): |
| return [p for name, p in self.named_parameters() if not name.startswith('clip_model.')] |
|
|
| def load_and_freeze_clip(self, clip_version): |
| clip_model, clip_preprocess = clip.load(clip_version, device='cpu', |
| jit=False) |
| |
| clip.model.convert_weights( |
| clip_model) |
| |
|
|
| |
| clip_model.eval() |
| for p in clip_model.parameters(): |
| p.requires_grad = False |
|
|
| return clip_model |
|
|
| def encode_text(self, raw_text): |
| device = next(self.parameters()).device |
| text = clip.tokenize(raw_text, truncate=True).to(device) |
| feat_clip_text = self.clip_model.encode_text(text).float() |
| return feat_clip_text |
|
|
|
|
| def q_schedule(self, bs, low, high): |
| noise = uniform((bs,), device=self.opt.device) |
| schedule = 1 - cosine_schedule(noise) |
| return torch.round(schedule * (high - low)) + low |
|
|
    def process_embed_proj_weight(self):
        """Assemble the full per-layer embedding/projection stacks when the
        middle layers are shared (share_weight and not shared_codebook).

        NOTE(review): this rebinds self.output_proj_weight / token_embed_weight
        to tensors derived from the shared parameters, so it must be re-run
        after every optimizer step — forward/generate/edit each call it first.
        Confirm the rebinding does not detach gradients unintentionally.
        """
        if self.share_weight and (not self.shared_codebook):
            # Output projection: shared middle layers + its private last layer.
            self.output_proj_weight = torch.cat([self.embed_proj_shared_weight, self.output_proj_weight_], dim=0)
            # Token embedding: its private first layer + shared middle layers.
            self.token_embed_weight = torch.cat([self.token_embed_weight_, self.embed_proj_shared_weight], dim=0)
| |
|
|
| def output_project(self, logits, qids): |
| ''' |
| :logits: (bs, code_dim, seqlen) |
| :qids: (bs) |
| |
| :return: |
| -logits (bs, ntoken, seqlen) |
| ''' |
| |
| output_proj_weight = self.output_proj_weight[qids] |
| |
| output_proj_bias = None if self.output_proj_bias is None else self.output_proj_bias[qids] |
|
|
| output = torch.einsum('bnc, bcs->bns', output_proj_weight, logits) |
| if output_proj_bias is not None: |
| output += output + output_proj_bias.unsqueeze(-1) |
| return output |
|
|
|
|
|
|
    def trans_forward(self, motion_codes, qids, cond, padding_mask, force_mask=False):
        '''
        :param motion_codes: (b, seqlen, d)
        :padding_mask: (b, seqlen), all pad positions are TRUE else FALSE
        :param qids: (b), quantizer layer ids
        :param cond: (b, embed_dim) for text, (b, num_actions) for action
        :return:
            -logits: (b, num_token, seqlen)
        '''
        # Classifier-free guidance: possibly zero the condition.
        cond = self.mask_cond(cond, force_mask=force_mask)

        # (b, seqlen, d) -> (seqlen, b, latent_dim)
        x = self.input_process(motion_codes)

        # One-hot of each sample's target quantizer layer: (b, num_quantizers).
        q_onehot = self.encode_quant(qids).float().to(x.device)

        q_emb = self.quant_emb(q_onehot).unsqueeze(0)  # (1, b, latent_dim)
        cond = self.cond_emb(cond).unsqueeze(0)  # (1, b, latent_dim)

        x = self.position_enc(x)
        # Two leading "tokens": condition, then quantizer-id embedding.
        xseq = torch.cat([cond, q_emb, x], dim=0)

        # Never pad the two prepended slots; drop them from the output.
        padding_mask = torch.cat([torch.zeros_like(padding_mask[:, 0:2]), padding_mask], dim=1)
        output = self.seqTransEncoder(xseq, src_key_padding_mask=padding_mask)[2:]
        logits = self.output_process(output)
        return logits
|
|
| def forward_with_cond_scale(self, |
| motion_codes, |
| q_id, |
| cond_vector, |
| padding_mask, |
| cond_scale=3, |
| force_mask=False): |
| bs = motion_codes.shape[0] |
| |
| qids = torch.full((bs,), q_id, dtype=torch.long, device=motion_codes.device) |
| if force_mask: |
| logits = self.trans_forward(motion_codes, qids, cond_vector, padding_mask, force_mask=True) |
| logits = self.output_project(logits, qids-1) |
| return logits |
|
|
| logits = self.trans_forward(motion_codes, qids, cond_vector, padding_mask) |
| logits = self.output_project(logits, qids-1) |
| if cond_scale == 1: |
| return logits |
|
|
| aux_logits = self.trans_forward(motion_codes, qids, cond_vector, padding_mask, force_mask=True) |
| aux_logits = self.output_project(aux_logits, qids-1) |
|
|
| scaled_logits = aux_logits + (logits - aux_logits) * cond_scale |
| return scaled_logits |
|
|
| def forward(self, all_indices, y, m_lens): |
| ''' |
| :param all_indices: (b, n, q) |
| :param y: raw text for cond_mode=text, (b, ) for cond_mode=action |
| :m_lens: (b,) |
| :return: |
| ''' |
|
|
| self.process_embed_proj_weight() |
|
|
| bs, ntokens, num_quant_layers = all_indices.shape |
| device = all_indices.device |
|
|
| |
| non_pad_mask = lengths_to_mask(m_lens, ntokens) |
|
|
| q_non_pad_mask = repeat(non_pad_mask, 'b n -> b n q', q=num_quant_layers) |
| all_indices = torch.where(q_non_pad_mask, all_indices, self.pad_id) |
|
|
| |
| active_q_layers = q_schedule(bs, low=1, high=num_quant_layers, device=device) |
|
|
| |
| token_embed = repeat(self.token_embed_weight, 'q c d-> b c d q', b=bs) |
| gather_indices = repeat(all_indices[..., :-1], 'b n q -> b n d q', d=token_embed.shape[2]) |
| |
| all_codes = token_embed.gather(1, gather_indices) |
|
|
| cumsum_codes = torch.cumsum(all_codes, dim=-1) |
|
|
| active_indices = all_indices[torch.arange(bs), :, active_q_layers] |
| history_sum = cumsum_codes[torch.arange(bs), :, :, active_q_layers - 1] |
|
|
| force_mask = False |
| if self.cond_mode == 'text': |
| with torch.no_grad(): |
| cond_vector = self.encode_text(y) |
| elif self.cond_mode == 'action': |
| cond_vector = self.enc_action(y).to(device).float() |
| elif self.cond_mode == 'uncond': |
| cond_vector = torch.zeros(bs, self.latent_dim).float().to(device) |
| force_mask = True |
| else: |
| raise NotImplementedError("Unsupported condition mode!!!") |
|
|
| logits = self.trans_forward(history_sum, active_q_layers, cond_vector, ~non_pad_mask, force_mask) |
| logits = self.output_project(logits, active_q_layers-1) |
| ce_loss, pred_id, acc = cal_performance(logits, active_indices, ignore_index=self.pad_id) |
|
|
| return ce_loss, pred_id, acc |
|
|
| @torch.no_grad() |
| @eval_decorator |
| def generate(self, |
| motion_ids, |
| conds, |
| m_lens, |
| temperature=1, |
| topk_filter_thres=0.9, |
| cond_scale=2, |
| num_res_layers=-1, |
| ): |
|
|
| |
| |
| self.process_embed_proj_weight() |
|
|
| device = next(self.parameters()).device |
| seq_len = motion_ids.shape[1] |
| batch_size = len(conds) |
|
|
| if self.cond_mode == 'text': |
| with torch.no_grad(): |
| cond_vector = self.encode_text(conds) |
| elif self.cond_mode == 'action': |
| cond_vector = self.enc_action(conds).to(device) |
| elif self.cond_mode == 'uncond': |
| cond_vector = torch.zeros(batch_size, self.latent_dim).float().to(device) |
| else: |
| raise NotImplementedError("Unsupported condition mode!!!") |
|
|
| |
| |
| |
|
|
| |
| padding_mask = ~lengths_to_mask(m_lens, seq_len) |
| |
| motion_ids = torch.where(padding_mask, self.pad_id, motion_ids) |
| all_indices = [motion_ids] |
| history_sum = 0 |
| num_quant_layers = self.opt.num_quantizers if num_res_layers==-1 else num_res_layers+1 |
|
|
| for i in range(1, num_quant_layers): |
| |
| |
| |
| token_embed = self.token_embed_weight[i-1] |
| token_embed = repeat(token_embed, 'c d -> b c d', b=batch_size) |
| gathered_ids = repeat(motion_ids, 'b n -> b n d', d=token_embed.shape[-1]) |
| history_sum += token_embed.gather(1, gathered_ids) |
|
|
| logits = self.forward_with_cond_scale(history_sum, i, cond_vector, padding_mask, cond_scale=cond_scale) |
| |
|
|
| logits = logits.permute(0, 2, 1) |
| |
| filtered_logits = top_k(logits, topk_filter_thres, dim=-1) |
|
|
| pred_ids = gumbel_sample(filtered_logits, temperature=temperature, dim=-1) |
|
|
| |
| |
| |
| |
|
|
| ids = torch.where(padding_mask, self.pad_id, pred_ids) |
|
|
| motion_ids = ids |
| all_indices.append(ids) |
|
|
| all_indices = torch.stack(all_indices, dim=-1) |
| |
| |
| all_indices = torch.where(all_indices==self.pad_id, -1, all_indices) |
| |
| return all_indices |
|
|
| @torch.no_grad() |
| @eval_decorator |
| def edit(self, |
| motion_ids, |
| conds, |
| m_lens, |
| temperature=1, |
| topk_filter_thres=0.9, |
| cond_scale=2 |
| ): |
|
|
| |
| |
| self.process_embed_proj_weight() |
|
|
| device = next(self.parameters()).device |
| seq_len = motion_ids.shape[1] |
| batch_size = len(conds) |
|
|
| if self.cond_mode == 'text': |
| with torch.no_grad(): |
| cond_vector = self.encode_text(conds) |
| elif self.cond_mode == 'action': |
| cond_vector = self.enc_action(conds).to(device) |
| elif self.cond_mode == 'uncond': |
| cond_vector = torch.zeros(batch_size, self.latent_dim).float().to(device) |
| else: |
| raise NotImplementedError("Unsupported condition mode!!!") |
|
|
| |
| |
| |
|
|
| |
| padding_mask = ~lengths_to_mask(m_lens, seq_len) |
| |
| motion_ids = torch.where(padding_mask, self.pad_id, motion_ids) |
| all_indices = [motion_ids] |
| history_sum = 0 |
|
|
| for i in range(1, self.opt.num_quantizers): |
| |
| |
| |
| token_embed = self.token_embed_weight[i-1] |
| token_embed = repeat(token_embed, 'c d -> b c d', b=batch_size) |
| gathered_ids = repeat(motion_ids, 'b n -> b n d', d=token_embed.shape[-1]) |
| history_sum += token_embed.gather(1, gathered_ids) |
|
|
| logits = self.forward_with_cond_scale(history_sum, i, cond_vector, padding_mask, cond_scale=cond_scale) |
| |
|
|
| logits = logits.permute(0, 2, 1) |
| |
| filtered_logits = top_k(logits, topk_filter_thres, dim=-1) |
|
|
| pred_ids = gumbel_sample(filtered_logits, temperature=temperature, dim=-1) |
|
|
| |
| |
| |
| |
|
|
| ids = torch.where(padding_mask, self.pad_id, pred_ids) |
|
|
| motion_ids = ids |
| all_indices.append(ids) |
|
|
| all_indices = torch.stack(all_indices, dim=-1) |
| |
| |
| all_indices = torch.where(all_indices==self.pad_id, -1, all_indices) |
| |
| return all_indices |