# GestureLSM — models/denoiser.py
# (HuggingFace upload by Tharun156, commit f7400bf)
import pdb
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from .layers.utils import *
from .layers.transformer import SpatialTemporalBlock, CrossAttentionBlock
class GestureDenoiser(nn.Module):
    """Transformer denoiser for gesture latent sequences.

    The noisy latent ``x_t`` is split per "joint" group (3 groups, or 4 when
    ``use_exp`` adds expression), fused with a timestep + seed embedding,
    rotary-position encoded, passed through cross-attention against audio
    features ``at_feat``, then through spatial-temporal self-attention blocks,
    and finally projected back to the input feature size.

    Args:
        input_dim: per-joint feature size of the latent.
        latent_dim: internal transformer width.
        ff_size: feed-forward hidden size (mlp_ratio = ff_size // latent_dim).
        num_layers: number of SpatialTemporalBlock layers.
        num_heads: attention heads for all blocks.
        dropout: dropout / drop-path rate.
        activation: activation name for the timestep MLP.
        n_seed: number of seed frames; 0 disables the seed-pose conditioning.
        flip_sin_to_cos, freq_shift: sinusoidal-timestep projection options.
        cond_proj_dim: when not None, enables a second timestep projection used
            as an extra condition (e.g. guidance-scale embedding).
        use_exp: include the expression group (joint_num becomes 4).
        seq_len: number of frames per sequence.
        embed_context_multiplier: how many seed frames are flattened into the
            seed-embedding linear layer.
    """

    def __init__(self,
                 input_dim=128,
                 latent_dim=256,
                 ff_size=1024,
                 num_layers=8,
                 num_heads=4,
                 dropout=0.1,
                 activation="gelu",
                 n_seed=8,
                 flip_sin_to_cos=True,
                 freq_shift=0,
                 cond_proj_dim=None,
                 use_exp=False,
                 seq_len=32,
                 embed_context_multiplier=4,
                 ):
        super().__init__()
        self.input_dim = input_dim
        self.latent_dim = latent_dim
        self.ff_size = ff_size
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.dropout = dropout
        self.activation = activation
        self.use_exp = use_exp
        # 3 body groups by default; expression adds a 4th group.
        self.joint_num = 3 if not self.use_exp else 4

        self.sequence_pos_encoder = PositionalEncoding(self.latent_dim, self.dropout)

        # Cross-attention over the joint-concatenated width (latent_dim * joint_num).
        # (original author note: heads should eventually be 12; smaller values
        # were used to ease debugging)
        self.cross_attn_blocks = nn.ModuleList([
            CrossAttentionBlock(dim=self.latent_dim * self.joint_num,
                                num_heads=self.num_heads,
                                mlp_ratio=self.ff_size // self.latent_dim,
                                drop_path=self.dropout)
            for _ in range(3)])

        self.mytimmblocks = nn.ModuleList([
            SpatialTemporalBlock(dim=self.latent_dim,
                                 num_heads=self.num_heads,
                                 mlp_ratio=self.ff_size // self.latent_dim,
                                 drop_path=self.dropout)
            for _ in range(self.num_layers)])

        self.embed_timestep = TimestepEmbedder(self.latent_dim, self.sequence_pos_encoder)
        self.n_seed = n_seed
        self.seq_len = seq_len
        self.embed_context_multiplier = embed_context_multiplier
        # Flattened seed poses -> single style embedding.
        self.embed_text = nn.Linear(
            self.input_dim * self.joint_num * self.embed_context_multiplier,
            self.latent_dim)
        self.output_process = OutputProcess(self.input_dim, self.latent_dim)
        self.rel_pos = SinusoidalEmbeddings(self.latent_dim)
        self.input_process = InputProcess(self.input_dim, self.latent_dim)
        # Fuses the per-frame latent with the broadcast style/time embedding.
        self.input_process2 = nn.Linear(self.latent_dim * 2, self.latent_dim)
        self.time_embedding = TimestepEmbedding(
            self.latent_dim, self.latent_dim, self.activation,
            cond_proj_dim=cond_proj_dim, zero_init_cond=True)
        time_dim = self.latent_dim
        self.time_proj = Timesteps(time_dim, flip_sin_to_cos, freq_shift)
        # Always define the attribute so forward() can safely test it;
        # the original only assigned it when cond_proj_dim was set, which made
        # the `self.cond_proj is not None` check raise AttributeError otherwise.
        self.cond_proj = Timesteps(time_dim, flip_sin_to_cos, freq_shift) if cond_proj_dim is not None else None
        # Null condition embedding for classifier-free guidance.
        self.null_cond_embed = nn.Parameter(
            torch.zeros(self.seq_len, self.latent_dim * self.joint_num),
            requires_grad=True)

    def prob_mask_like(self, shape, prob, device):
        """Return a boolean mask of `shape` where each entry is True with
        probability `prob` (used to drop conditions for CFG training)."""
        return torch.zeros(shape, device=device).float().uniform_(0, 1) < prob

    def forward(self, x, timesteps, cond_time=None, seed=None, at_feat=None):
        """Denoise one step.

        Args:
            x: noisy latent ``x_t``; a singleton dim-2 is squeezed, then the
                feature axis is split into ``joint_num`` groups, giving
                [bs, joint_num, nfeats, nframes].
            timesteps: [bs] integer diffusion timesteps.
            cond_time: optional extra condition timesteps (requires the model
                to have been built with ``cond_proj_dim``).
            seed: seed poses, flattened to [bs, -1]; required when n_seed != 0.
            at_feat: audio features consumed by the cross-attention blocks.

        Returns:
            Denoised output from ``output_process`` (same layout as its input).
        """
        if x.shape[2] == 1:
            x = x.squeeze(2)
        # Split channels into joint groups: [bs, joint_num, nfeats, nframes].
        x = x.reshape(x.shape[0], self.joint_num, -1, x.shape[2])
        bs, njoints, nfeats, nframes = x.shape  # e.g. [bs, 3, 128, 32]

        time_emb = self.time_proj(timesteps).to(dtype=x.dtype)
        if cond_time is not None and self.cond_proj is not None:
            # cond_time must broadcast to the batch, even when bs == 1.
            cond_time = cond_time.expand(bs).clone()
            cond_emb = self.cond_proj(cond_time).to(dtype=x.dtype)
            emb_t = self.time_embedding(time_emb, cond_emb)
        else:
            emb_t = self.time_embedding(time_emb)

        # Style embedding = timestep embedding (+ seed-pose embedding when
        # seed conditioning is enabled). The original used `emb_seed`
        # unconditionally, which raised NameError when n_seed == 0.
        if self.n_seed != 0:
            if seed is None:
                raise ValueError("`seed` must be provided when n_seed != 0")
            style_emb = self.embed_text(seed.reshape(bs, -1)) + emb_t
        else:
            style_emb = emb_t

        xseq = self.input_process(x)
        # Broadcast the style embedding over joints and frames, then fuse.
        style = style_emb.unsqueeze(1).unsqueeze(2).expand(-1, self.joint_num, self.seq_len, -1)
        xseq = torch.cat([style, xseq], dim=-1)
        xseq = self.input_process2(xseq)

        # Rotary positional encoding applied per (batch x joint) sequence.
        xseq = xseq.reshape(bs * self.joint_num, nframes, -1)
        pos_emb = self.rel_pos(xseq)
        xseq, _ = apply_rotary_pos_emb(xseq, xseq, pos_emb)
        xseq = xseq.reshape(bs, self.joint_num, nframes, -1)

        # NOTE(review): this view folds (joint, frame, dim) into
        # (frame, joint*dim) WITHOUT a permute, so the grouping is scrambled
        # relative to a true per-frame concatenation. It is undone by the
        # matching view below, so it is self-consistent — kept as-is to match
        # trained checkpoints; confirm intent before "fixing".
        xseq = xseq.view(bs, self.seq_len, -1)
        for block in self.cross_attn_blocks:
            xseq = block(xseq, at_feat)
        xseq = xseq.view(bs, njoints, self.seq_len, -1)
        for block in self.mytimmblocks:
            xseq = block(xseq)

        return self.output_process(xseq)