|
|
import math
|
|
|
import torch
|
|
|
import torch.nn as nn
|
|
|
from torch.nn import functional as F
|
|
|
from torch.distributions import Categorical
|
|
|
import models.pos_encoding as pos_encoding
|
|
|
|
|
|
class Text2Motion_Transformer(nn.Module):
    """Autoregressive transformer over motion-token indices, conditioned on a CLIP feature.

    The model is split into a trunk (``CrossCondTransBase``), which embeds the
    conditioning vector together with the token prefix, and a head
    (``CrossCondTransHead``), which maps hidden states to ``num_vq + 1``
    logits. :meth:`sample` treats index ``num_vq`` as the end-of-sequence token.
    """

    def __init__(self,
                 num_vq=1024,
                 embed_dim=512,
                 clip_dim=512,
                 block_size=16,
                 num_layers=2,
                 n_head=8,
                 drop_out_rate=0.1,
                 fc_rate=4):
        super().__init__()
        # Registration order (trunk, then head) fixes parameter order and the RNG
        # consumed during weight init; keep it stable for checkpoint/seed compatibility.
        self.trans_base = CrossCondTransBase(num_vq, embed_dim, clip_dim, block_size, num_layers, n_head, drop_out_rate, fc_rate)
        self.trans_head = CrossCondTransHead(num_vq, embed_dim, block_size, num_layers, n_head, drop_out_rate, fc_rate)
        self.block_size = block_size
        self.num_vq = num_vq

    def get_block_size(self):
        """Return the number of positions (conditioning token included) the model can attend over."""
        return self.block_size

    def forward(self, idxs, clip_feature):
        """Compute next-token logits for every position of the prefix.

        Args:
            idxs: LongTensor of token indices with shape ``(B, t)``, or an empty
                list to run on the conditioning token alone.
            clip_feature: conditioning tensor of shape ``(B, clip_dim)``.

        Returns:
            Logits of shape ``(B, t + 1, num_vq + 1)``. Position 0 corresponds
            to the conditioning token.
        """
        feat = self.trans_base(idxs, clip_feature)
        logits = self.trans_head(feat)
        return logits

    def sample(self, clip_feature, if_categorial=False):
        """Autoregressively decode token indices for one conditioning vector.

        At each step, the logits of the last position are turned into
        probabilities. The next index is then either the argmax (``topk`` with
        ``k=1``) or, when ``if_categorial`` is true, a draw from a
        ``Categorical`` distribution. Decoding stops as soon as the end token
        (index ``num_vq``) is chosen; the end token is not included in the
        result.

        Args:
            clip_feature: conditioning tensor of shape ``(1, clip_dim)``.
                NOTE(review): only batch size 1 is supported. The categorical
                branch tests a ``(B,)`` tensor in an ``if`` (this raises for
                ``B > 1``), and the greedy branch only checks ``idx[0]``.
            if_categorial: sample from the distribution instead of taking the argmax.

        Returns:
            LongTensor of shape ``(1, L)`` with the decoded indices:

            - ``L == 0`` if the end token is chosen at the first step.
            - ``L == block_size - 1`` if no end token appears within
              ``block_size`` steps; the last sampled index is dropped in
              that case.
        """
        xs = None  # decoded prefix; None until the first non-end token is accepted
        for k in range(self.block_size):
            # The trunk treats an empty list as "conditioning token only".
            x = [] if xs is None else xs
            logits = self.forward(x, clip_feature)
            logits = logits[:, -1, :]
            probs = F.softmax(logits, dim=-1)
            if if_categorial:
                dist = Categorical(probs)
                idx = dist.sample()
                if idx == self.num_vq:
                    break
                idx = idx.unsqueeze(-1)
            else:
                _, idx = torch.topk(probs, k=1, dim=-1)
                if idx[0] == self.num_vq:
                    break

            # Append the accepted index and continue.
            xs = idx if xs is None else torch.cat((xs, idx), dim=1)

            if k == self.block_size - 1:
                # Context is full without an end token: drop the last sampled index.
                return xs[:, :-1]

        if xs is None:
            # End token chosen at the very first step (or block_size == 0). Return an
            # empty index sequence rather than failing on an unbound local.
            return torch.zeros((clip_feature.size(0), 0), dtype=torch.long, device=clip_feature.device)
        return xs
|
|
|
|
|
|
class CausalCrossConditionalSelfAttention(nn.Module):
    """Multi-head self-attention with a causal (lower-triangular) mask.

    Despite the name, queries, keys and values are all projected from the same
    input ``x``. Any conditioning signal reaches this layer only as part of
    that input sequence.
    """

    def __init__(self, embed_dim=512, block_size=16, n_head=8, drop_out_rate=0.1):
        super().__init__()
        # forward() splits the channels into n_head heads of embed_dim // n_head each,
        # so the division must be exact; check it against n_head, not a fixed constant.
        assert embed_dim % n_head == 0, \
            f"embed_dim ({embed_dim}) must be divisible by n_head ({n_head})"

        # Key, query and value projections for all heads.
        self.key = nn.Linear(embed_dim, embed_dim)
        self.query = nn.Linear(embed_dim, embed_dim)
        self.value = nn.Linear(embed_dim, embed_dim)

        self.attn_drop = nn.Dropout(drop_out_rate)
        self.resid_drop = nn.Dropout(drop_out_rate)

        # Output projection.
        self.proj = nn.Linear(embed_dim, embed_dim)

        # Lower-triangular mask of shape (1, 1, block_size, block_size): position i
        # may attend to positions j <= i. It bounds the sequence length to block_size.
        self.register_buffer("mask", torch.tril(torch.ones(block_size, block_size)).view(1, 1, block_size, block_size))
        self.n_head = n_head

    def forward(self, x):
        """Apply causal multi-head self-attention.

        Args:
            x: tensor of shape ``(B, T, C)`` with ``C == embed_dim`` and
                ``T <= block_size`` (the mask buffer only covers
                ``block_size`` positions).

        Returns:
            Tensor of shape ``(B, T, C)``.
        """
        B, T, C = x.size()
        head_dim = C // self.n_head

        # Project and split heads: (B, T, C) -> (B, n_head, T, head_dim).
        k = self.key(x).view(B, T, self.n_head, head_dim).transpose(1, 2)
        q = self.query(x).view(B, T, self.n_head, head_dim).transpose(1, 2)
        v = self.value(x).view(B, T, self.n_head, head_dim).transpose(1, 2)

        # Scaled dot-product scores: (B, n_head, T, T).
        att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
        att = att.masked_fill(self.mask[:, :, :T, :T] == 0, float('-inf'))
        att = F.softmax(att, dim=-1)
        att = self.attn_drop(att)
        y = att @ v

        # Re-assemble heads side by side: (B, n_head, T, head_dim) -> (B, T, C).
        y = y.transpose(1, 2).contiguous().view(B, T, C)

        y = self.resid_drop(self.proj(y))
        return y
|
|
|
|
|
|
class Block(nn.Module):
    """Pre-LayerNorm transformer layer.

    The layer applies causal self-attention and then a feed-forward MLP. Each
    sub-layer normalizes its input and adds its output back to the residual
    stream.
    """

    def __init__(self, embed_dim=512, block_size=16, n_head=8, drop_out_rate=0.1, fc_rate=4):
        super().__init__()
        hidden_dim = fc_rate * embed_dim

        # Registration order determines parameter order and state_dict layout; keep it stable.
        self.ln1 = nn.LayerNorm(embed_dim)
        self.ln2 = nn.LayerNorm(embed_dim)
        self.attn = CausalCrossConditionalSelfAttention(embed_dim, block_size, n_head, drop_out_rate)

        # Feed-forward: expand by fc_rate, GELU, project back, dropout.
        expand = nn.Linear(embed_dim, hidden_dim)
        contract = nn.Linear(hidden_dim, embed_dim)
        self.mlp = nn.Sequential(expand, nn.GELU(), contract, nn.Dropout(drop_out_rate))

    def forward(self, x):
        """Return ``x`` after the attention and MLP residual sub-layers (shape preserved)."""
        attended = x + self.attn(self.ln1(x))
        return attended + self.mlp(self.ln2(attended))
|
|
|
|
|
|
class CrossCondTransBase(nn.Module):
    """Transformer trunk that embeds a conditioning vector and a token prefix.

    The conditioning vector is projected to ``embed_dim`` and placed at
    position 0, ahead of the token embeddings. The result goes through a
    positional embedding and a stack of causal ``Block`` layers. Because the
    conditioning token takes one of the ``block_size`` positions, at most
    ``block_size - 1`` tokens can be processed.
    """

    def __init__(self,
                 num_vq=1024,
                 embed_dim=512,
                 clip_dim=512,
                 block_size=16,
                 num_layers=2,
                 n_head=8,
                 drop_out_rate=0.1,
                 fc_rate=4):
        super().__init__()
        # num_vq + 2 rows: presumably the codebook entries plus end/pad tokens — TODO confirm.
        self.tok_emb = nn.Embedding(num_vq + 2, embed_dim)
        self.cond_emb = nn.Linear(clip_dim, embed_dim)
        # NOTE(review): pos_embedding and drop are never used in forward (pos_embed is).
        # They are kept so existing state_dict keys and the init RNG order stay unchanged;
        # verify against saved checkpoints before removing.
        self.pos_embedding = nn.Embedding(block_size, embed_dim)
        self.drop = nn.Dropout(drop_out_rate)

        self.blocks = nn.Sequential(*[Block(embed_dim, block_size, n_head, drop_out_rate, fc_rate) for _ in range(num_layers)])
        self.pos_embed = pos_encoding.PositionEmbedding(block_size, embed_dim, 0.0, False)

        self.block_size = block_size

        self.apply(self._init_weights)

    def get_block_size(self):
        """Return the number of positions (conditioning token included) the trunk can process."""
        return self.block_size

    def _init_weights(self, module):
        """Initialize one submodule in place; used with ``nn.Module.apply``.

        - Linear/Embedding weights are drawn from N(0, 0.02).
        - Linear biases are set to zero.
        - LayerNorm weight/bias are reset to 1/0.
        """
        if isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if isinstance(module, nn.Linear) and module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    def forward(self, idx, clip_feature):
        """Embed the conditioning vector and token prefix, then run the transformer blocks.

        Args:
            idx: LongTensor of token indices with shape ``(B, t)``, or an empty
                list to process the conditioning token alone.
            clip_feature: tensor of shape ``(B, clip_dim)``.

        Returns:
            Hidden states of shape ``(B, t + 1, embed_dim)``. Position 0 is the
            conditioning token.

        Raises:
            AssertionError: if ``t + 1 > block_size``. The causal masks built
                for each ``Block`` cover only ``block_size`` positions, so
                longer inputs would fail later with an opaque shape error.
        """
        cond = self.cond_emb(clip_feature).unsqueeze(1)  # (B, 1, embed_dim)
        if len(idx) == 0:
            token_embeddings = cond
        else:
            _, t = idx.size()
            # The conditioning token occupies one position, leaving block_size - 1 for tokens.
            assert t + 1 <= self.block_size, "Cannot forward, model block size is exhausted."
            token_embeddings = torch.cat([cond, self.tok_emb(idx)], dim=1)

        x = self.pos_embed(token_embeddings)
        x = self.blocks(x)
        return x
|
|
|
|
|
|
|
|
|
class CrossCondTransHead(nn.Module):
    """Prediction head on top of the trunk's hidden states.

    It applies further causal ``Block`` layers, a final LayerNorm, and a
    bias-free projection to ``num_vq + 1`` logits.
    """

    def __init__(self,
                 num_vq=1024,
                 embed_dim=512,
                 block_size=16,
                 num_layers=2,
                 n_head=8,
                 drop_out_rate=0.1,
                 fc_rate=4):
        super().__init__()
        layers = [Block(embed_dim, block_size, n_head, drop_out_rate, fc_rate) for _ in range(num_layers)]
        self.blocks = nn.Sequential(*layers)
        self.ln_f = nn.LayerNorm(embed_dim)
        # One output class beyond the codebook (index num_vq), presumably end-of-sequence.
        self.head = nn.Linear(embed_dim, num_vq + 1, bias=False)
        self.block_size = block_size

        self.apply(self._init_weights)

    def get_block_size(self):
        """Return the number of positions the head's attention masks cover."""
        return self.block_size

    def _init_weights(self, module):
        """Initialize one submodule in place; used with ``nn.Module.apply``.

        - Linear/Embedding weights are drawn from N(0, 0.02).
        - Linear biases are set to zero.
        - LayerNorm weight/bias are reset to 1/0.
        """
        is_dense = isinstance(module, (nn.Linear, nn.Embedding))
        if is_dense:
            module.weight.data.normal_(mean=0.0, std=0.02)
            has_bias = isinstance(module, nn.Linear) and module.bias is not None
            if has_bias:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            module.weight.data.fill_(1.0)
            module.bias.data.zero_()

    def forward(self, x):
        """Map hidden states ``(B, T, embed_dim)`` to logits ``(B, T, num_vq + 1)``."""
        hidden = self.ln_f(self.blocks(x))
        return self.head(hidden)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|