import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from .kan import KANLayer


class PositionalEncoding(nn.Module):
    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000):
        """
        Standard positional encoding with sin/cos functions plus LayerNorm to
        preserve temporal relationships between frames throughout sequence modeling.
        """
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        self.d_model = d_model

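        # Precompute the sinusoidal table once: even feature indices take sin, odd
        # take cos, with frequencies decreasing geometrically across the feature dim.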
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(1)  # (max_len, 1, d_model) so it broadcasts over the batch dim
        self.register_buffer("pe", pe)
        self.norm_pe = nn.LayerNorm(d_model)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: Input tensor of shape (seq_len, batch_size, d_model)
        Returns:
            Tensor with positional encodings added and normalized
        """
        seq_len = x.size(0)
        x2 = x + self.pe[:seq_len, :]
        x2 = self.norm_pe(x2)
        return self.dropout(x2)


class Encoder_TRANSFORMER(nn.Module):
    """
    Encoder module using a Transformer architecture with KAN layers.
    Key components:
    - KANLayer, which replaces linear projections with learnable 1D splines;
    - a Transformer encoder that models temporal dependencies.
    """
    def __init__(
        self,
        modeltype,
        njoints: int,
        nfeats: int,
        num_frames: int,
        num_classes: int,
        translation,
        pose_rep,
        glob,
        glob_rot,
        latent_dim: int = 256,
        ff_size: int = 1024,
        num_layers: int = 4,
        num_heads: int = 4,
        dropout: float = 0.1,
        activation: str = "gelu",
        **kwargs
    ):
        super().__init__()
        self.njoints = njoints
        self.nfeats = nfeats
        self.num_frames = num_frames
        self.num_classes = num_classes
        self.pose_rep = pose_rep
        self.glob = glob
        self.glob_rot = glob_rot
        self.translation = translation

        self.latent_dim = latent_dim
        self.ff_size = ff_size
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.dropout = dropout
        self.activation = activation

        self.input_feats = self.njoints * self.nfeats

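        # Learnable query tokens prepended to every sequence; after encoding, their
        # output positions are read out as the posterior mean and log-variance.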
        self.muQuery = nn.Parameter(torch.randn(1, self.latent_dim))
        self.sigmaQuery = nn.Parameter(torch.randn(1, self.latent_dim))

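        # KAN-based embedding of the flattened per-frame skeleton into the latent
        # dimension (a spline-parameterized alternative to a linear projection).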
        self.skelEmbedding = KANLayer(self.input_feats, self.latent_dim)

        self.sequence_pos_encoder = PositionalEncoding(self.latent_dim, self.dropout)

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=self.latent_dim,
            nhead=self.num_heads,
            dim_feedforward=self.ff_size,
            dropout=self.dropout,
            activation=self.activation
        )
        self.seqTransEncoder = nn.TransformerEncoder(encoder_layer, num_layers=self.num_layers)
        self.encoder_norm = nn.LayerNorm(self.latent_dim)

    def forward(self, batch: dict) -> dict:
        """
        batch["x"]: (batch, njoints, nfeats, nframes)
        batch["y"]: (batch,) class labels (zeros when unlabeled)
        batch["mask"]: (batch, nframes) boolean mask of valid frames
        """
        x, y, mask = batch["x"], batch["y"], batch["mask"]
        bs, nj, nf, nframes = x.shape
        assert nframes == self.num_frames, "Frame dimension mismatch"

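        # Reorder to (nframes, bs, njoints * nfeats): sequence-first layout expected
        # by nn.Transformer, with each frame's skeleton flattened into one vector.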
        x_seq = x.permute(3, 0, 1, 2).reshape(self.num_frames, bs, self.input_feats)
        x_emb = self.skelEmbedding(x_seq)

        # Sanitize labels; y is not used further in this forward pass.
        if y is None:
            y = torch.zeros(bs, dtype=torch.long, device=x.device)
        else:
            y = y.clamp(0, self.num_classes - 1)

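        # Prepend the learnable mu/sigma query tokens to every sequence in the batch.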
        mu_init = self.muQuery.expand(bs, -1).unsqueeze(0)        # (1, bs, latent_dim)
        sigma_init = self.sigmaQuery.expand(bs, -1).unsqueeze(0)  # (1, bs, latent_dim)
        xcat = torch.cat((mu_init, sigma_init, x_emb), dim=0)     # (nframes + 2, bs, latent_dim)

        # The two query tokens are always attended to; PyTorch requires a bool mask.
        if mask.dtype != torch.bool:
            mask = mask.bool()
        mu_sigma_mask = torch.ones((bs, 2), dtype=torch.bool, device=x.device)
        mask_seq = torch.cat((mu_sigma_mask, mask), dim=1)

        xcat_pe = self.sequence_pos_encoder(xcat)

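        # src_key_padding_mask expects True at PADDED positions, hence the inversion.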
        encoded = self.seqTransEncoder(
            xcat_pe,
            src_key_padding_mask=~mask_seq
        )
        encoded = self.encoder_norm(encoded)

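        # Read the distribution parameters off the two prepended token positions.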
        mu = encoded[0]
        logvar = encoded[1]

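        # Reparameterization trick: z = mu + eps * std keeps sampling differentiable
        # with respect to mu and logvar, with std = exp(logvar / 2).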
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        z = mu + eps * std

        return {"mu": mu, "logvar": logvar, "z": z}


class Decoder_TRANSFORMER(nn.Module):
    """
    Decoder module using a Transformer architecture with a KAN layer:
    - KANLayer: final projection layer for skeleton reconstruction
    - Transformer decoder: parallel (non-autoregressive) generation of all frames
      from positional time queries attending to the latent code
    """
    def __init__(
        self,
        modeltype,
        njoints: int,
        nfeats: int,
        num_frames: int,
        num_classes: int,
        translation,
        pose_rep,
        glob,
        glob_rot,
        latent_dim: int = 256,
        ff_size: int = 1024,
        num_layers: int = 4,
        num_heads: int = 4,
        dropout: float = 0.1,
        activation: str = "gelu",
        **kwargs
    ):
        super().__init__()

|
|
| self.njoints = njoints
|
| self.nfeats = nfeats
|
| self.num_frames = num_frames
|
| self.num_classes = num_classes
|
| self.pose_rep = pose_rep
|
| self.glob = glob
|
| self.glob_rot = glob_rot
|
| self.translation = translation
|
|
|
| self.latent_dim = latent_dim
|
| self.ff_size = ff_size
|
| self.num_layers = num_layers
|
| self.num_heads = num_heads
|
| self.dropout = dropout
|
| self.activation = activation
|
|
|
| self.input_feats = self.njoints * self.nfeats
        # Learnable bias vector (declared here but not used in this forward pass).
        self.actionBiases = nn.Parameter(torch.randn(1, self.latent_dim))

        self.sequence_pos_encoder = PositionalEncoding(self.latent_dim, self.dropout)

        decoder_layer = nn.TransformerDecoderLayer(
            d_model=self.latent_dim,
            nhead=self.num_heads,
            dim_feedforward=self.ff_size,
            dropout=self.dropout,
            activation=self.activation
        )
        self.seqTransDecoder = nn.TransformerDecoder(decoder_layer, num_layers=self.num_layers)
        self.decoder_norm = nn.LayerNorm(self.latent_dim)

        # KAN-based projection from the latent dimension back to flattened skeleton features.
        self.finallayer = KANLayer(self.latent_dim, self.input_feats)

    def forward(self, batch: dict, use_text_emb: bool = False) -> dict:
        """
        Forward pass for the decoder.
        Args:
            batch: Dictionary containing latent codes and metadata
            use_text_emb: Whether to decode from CLIP text embeddings instead of latent codes
        Returns:
            The batch dictionary with the generated output added
        """
        z = batch["z"]
        y = batch["y"]
        mask = batch["mask"]
        lengths = batch.get("lengths", None)  # optional; not used below
        bs, nframes = mask.shape
        nj, nf = self.njoints, self.nfeats

        if use_text_emb:
            z = batch["clip_text_emb"]

        # Normalize the latent code and add a sequence dimension; the resulting
        # (1, bs, latent_dim) tensor is the single-token memory for cross-attention.
        z = F.layer_norm(z, (self.latent_dim,))
        z = z.unsqueeze(0)

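        # Frame "queries" start as zeros; the positional encoding alone tells the
        # decoder which time step each query represents.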
        timequeries = torch.zeros(nframes, bs, self.latent_dim, device=z.device)
        timequeries_pe = self.sequence_pos_encoder(timequeries)

        # Key-padding masks must be boolean.
        if mask.dtype != torch.bool:
            mask = mask.bool()

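        # tgt_key_padding_mask expects True at PADDED positions, hence the inversion.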
        dec_out = self.seqTransDecoder(
            tgt=timequeries_pe,
            memory=z,
            tgt_key_padding_mask=~mask
        )
        dec_out = self.decoder_norm(dec_out)

        skel_feats = self.finallayer(dec_out)
        skel_feats = skel_feats.view(nframes, bs, nj, nf)

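        # Zero out features of padded frames so they carry no signal downstream.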
        mask_t = mask.T  # (nframes, bs), aligned with the sequence-first layout
        skel_feats[~mask_t] = 0.0

        # Back to the input layout: (batch, njoints, nfeats, nframes).
        output = skel_feats.permute(1, 2, 3, 0).contiguous()

        if use_text_emb:
            batch["txt_output"] = output
        else:
            batch["output"] = output

        return batch
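

# Minimal smoke-test sketch (not part of the model). Assumptions: KANLayer(in_features,
# out_features) maps the last dimension like nn.Linear, and the modeltype / translation /
# pose_rep / glob / glob_rot arguments are passed through unused, as in the classes above.
# Run with `python -m <package>.<this_module>` so the relative import resolves.
if __name__ == "__main__":
    bs, njoints, nfeats, nframes, nclasses = 2, 25, 6, 60, 12
    enc = Encoder_TRANSFORMER("cvae", njoints, nfeats, nframes, nclasses,
                              translation=True, pose_rep="rot6d", glob=True, glob_rot=None)
    dec = Decoder_TRANSFORMER("cvae", njoints, nfeats, nframes, nclasses,
                              translation=True, pose_rep="rot6d", glob=True, glob_rot=None)
    batch = {
        "x": torch.randn(bs, njoints, nfeats, nframes),
        "y": torch.zeros(bs, dtype=torch.long),
        "mask": torch.ones(bs, nframes, dtype=torch.bool),
    }
    batch.update(enc(batch))      # adds "mu", "logvar", "z"
    batch = dec(batch)            # adds "output"
    print(batch["output"].shape)  # expected: torch.Size([2, 25, 6, 60])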