import math
import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.distributions import Categorical
import models.pos_encoding as pos_encoding
import numpy as np
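
# GPT-style autoregressive transformer for text-to-motion generation: it maps a
# CLIP text embedding to a sequence of motion VQ codebook indices, with index
# `num_vq` acting as the end-of-motion token (see `sample` below).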


class Text2Motion_Transformer(nn.Module):

    def __init__(self,
                 num_vq=1024,
                 embed_dim=512,
                 clip_dim=512,
                 block_size=16,
                 num_layers=2,
                 n_head=8,
                 drop_out_rate=0.1,
                 fc_rate=4):
        super().__init__()
        self.trans_base = CrossCondTransBase(num_vq, embed_dim, clip_dim, block_size, num_layers, n_head, drop_out_rate, fc_rate)
        self.trans_head = CrossCondTransHead(num_vq, embed_dim, block_size, num_layers, n_head, drop_out_rate, fc_rate)
        self.block_size = block_size
        self.num_vq = num_vq

    def get_block_size(self):
        return self.block_size

    def forward(self, idxs, clip_feature):
        feat = self.trans_base(idxs, clip_feature)
        logits = self.trans_head(feat)
        return logits

    def sample(self, clip_feature, if_categorial=False, att=False):
        # Start from an empty sequence so that an immediate end-of-motion token
        # still yields a well-defined (empty) result instead of an undefined name.
        xs = torch.empty(clip_feature.shape[0], 0, dtype=torch.long, device=clip_feature.device)
        for k in range(self.block_size):
            if k == 0:
                x = []
                logits = self.forward(x, clip_feature)
                if att:
                    # Return the first base block's attention weights instead of sampling.
                    return self.trans_base.blocks[0].get_attention_weights()
                logits = logits[:, -1, :]
                probs = F.softmax(logits, dim=-1)
            else:
                x = xs
                logits = self.forward(x, clip_feature)
                logits = logits[:, -1, :]
                probs = F.softmax(logits, dim=-1)

            if if_categorial:
                # Sample the next motion token from the predicted distribution.
                dist = Categorical(probs)
                idx = dist.sample()
                if idx == self.num_vq:  # end-of-motion token
                    break
                idx = idx.unsqueeze(-1)
            else:
                # Greedy decoding: pick the most likely token.
                _, idx = torch.topk(probs, k=1, dim=-1)
                if idx[0] == self.num_vq:  # end-of-motion token
                    break

            if k == 0:
                xs = idx
            else:
                xs = torch.cat((xs, idx), dim=1)

            if k == self.block_size - 1:
                # Keep at most block_size - 1 motion tokens; the text condition occupies one position.
                return xs[:, :-1]
        return xs


class CausalCrossConditionalSelfAttention(nn.Module):

    def __init__(self, embed_dim=512, block_size=16, n_head=8, drop_out_rate=0.1):
        super().__init__()
        assert embed_dim % n_head == 0
        # key, query, value projections for all heads
        self.key = nn.Linear(embed_dim, embed_dim)
        self.query = nn.Linear(embed_dim, embed_dim)
        self.value = nn.Linear(embed_dim, embed_dim)

        self.attn_drop = nn.Dropout(drop_out_rate)
        self.resid_drop = nn.Dropout(drop_out_rate)

        self.proj = nn.Linear(embed_dim, embed_dim)
        # causal mask: attention only looks at the current and previous positions
        self.register_buffer("mask", torch.tril(torch.ones(block_size, block_size)).view(1, 1, block_size, block_size))
        self.n_head = n_head
        self.att = None

    def forward(self, x):
        B, T, C = x.size()

        # project to per-head queries, keys, values: (B, n_head, T, head_dim)
        k = self.key(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
        q = self.query(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
        v = self.value(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2)

        # scaled dot-product attention with the causal mask
        att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
        att = att.masked_fill(self.mask[:, :, :T, :T] == 0, float('-inf'))
        att = F.softmax(att, dim=-1)
        att = self.attn_drop(att)
        y = att @ v
        y = y.transpose(1, 2).contiguous().view(B, T, C)  # re-assemble all heads side by side
        self.att = att  # cache the attention map for later inspection

        # output projection
        y = self.resid_drop(self.proj(y))
        return y

    def get_attention_weights(self):
        return self.att


class Block(nn.Module):

    def __init__(self, embed_dim=512, block_size=16, n_head=8, drop_out_rate=0.1, fc_rate=4, num_layers=-1, num=None):
        super().__init__()
        self.num_layers = num_layers
        self.num = num  # index of this block within the stack
        self.attn_weight = None
        self.ln1 = nn.LayerNorm(embed_dim)
        self.ln2 = nn.LayerNorm(embed_dim)
        self.attn = CausalCrossConditionalSelfAttention(embed_dim, block_size, n_head, drop_out_rate)
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, fc_rate * embed_dim),
            nn.GELU(),
            nn.Linear(fc_rate * embed_dim, embed_dim),
            nn.Dropout(drop_out_rate),
        )

    def forward(self, x):
        x = x + self.attn(self.ln1(x))
        if self.num == 0:
            # Keep the first block's attention weights so they can be queried after a forward pass.
            self.attn_weight = self.attn.get_attention_weights()
        x = x + self.mlp(self.ln2(x))
        return x

    def get_attention_weights(self):
        return self.attn_weight


class CrossCondTransBase(nn.Module):

    def __init__(self,
                 num_vq=1024,
                 embed_dim=512,
                 clip_dim=512,
                 block_size=16,
                 num_layers=2,
                 n_head=8,
                 drop_out_rate=0.1,
                 fc_rate=4):
        super().__init__()
        # num_vq codebook entries plus two special tokens (end-of-motion / padding)
        self.tok_emb = nn.Embedding(num_vq + 2, embed_dim)
        self.cond_emb = nn.Linear(clip_dim, embed_dim)
        self.pos_embedding = nn.Embedding(block_size, embed_dim)  # not used in forward; positions come from self.pos_embed below
        self.drop = nn.Dropout(drop_out_rate)
        # transformer blocks; num=_ lets the first block cache its attention weights
        self.blocks = nn.Sequential(*[Block(embed_dim, block_size, n_head, drop_out_rate, fc_rate, num=_) for _ in range(num_layers)])
        self.pos_embed = pos_encoding.PositionEmbedding(block_size, embed_dim, 0.0, False)

        self.block_size = block_size
        self.first_att_weights = None
        self.apply(self._init_weights)

    def get_block_size(self):
        return self.block_size

    def _init_weights(self, module):
        if isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if isinstance(module, nn.Linear) and module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    def forward(self, idx, clip_feature):
        if len(idx) == 0:
            # No motion tokens yet (first sampling step): the sequence is just the text condition.
            token_embeddings = self.cond_emb(clip_feature).unsqueeze(1)
        else:
            b, t = idx.size()
            assert t <= self.block_size, "Cannot forward, model block size is exhausted."
            token_embeddings = self.tok_emb(idx)
            # Prepend the text condition as the first token of the sequence.
            token_embeddings = torch.cat([self.cond_emb(clip_feature.to(torch.float32)).unsqueeze(1), token_embeddings], dim=1)

        x = self.pos_embed(token_embeddings)
        x = self.blocks(x)

        return x


class CrossCondTransHead(nn.Module):

    def __init__(self,
                 num_vq=1024,
                 embed_dim=512,
                 block_size=16,
                 num_layers=2,
                 n_head=8,
                 drop_out_rate=0.1,
                 fc_rate=4):
        super().__init__()

        self.blocks = nn.Sequential(*[Block(embed_dim, block_size, n_head, drop_out_rate, fc_rate, num=_) for _ in range(num_layers)])
        self.ln_f = nn.LayerNorm(embed_dim)
        # num_vq + 1 output logits: the codebook entries plus the end-of-motion token
        self.head = nn.Linear(embed_dim, num_vq + 1, bias=False)
        self.block_size = block_size

        self.apply(self._init_weights)

    def get_block_size(self):
        return self.block_size

    def _init_weights(self, module):
        if isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if isinstance(module, nn.Linear) and module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    def forward(self, x):
        x = self.blocks(x)
        x = self.ln_f(x)
        logits = self.head(x)
        return logits
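

if __name__ == "__main__":
    # Minimal smoke test (a sketch, not part of the training/eval pipeline): in the
    # real setup the condition vector would be a CLIP text embedding of size clip_dim;
    # here a random vector stands in for it.
    model = Text2Motion_Transformer(num_vq=1024, embed_dim=512, clip_dim=512,
                                    block_size=16, num_layers=2, n_head=8)
    model.eval()
    clip_feature = torch.randn(1, 512)  # stand-in for a CLIP text feature
    with torch.no_grad():
        motion_indices = model.sample(clip_feature, if_categorial=False)
    print(motion_indices.shape)  # (1, n) with n <= block_size - 1 motion token indices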