"""
GeoMotionGPT Model
This module contains the model implementation for GeoMotionGPT, integrating:
1. Motion Tokenizer (DVQ-GSST VQ-VAE)
2. Language Model (fine-tuned GPT-2 for motion-to-text)
Usage:
```python
from transformers import AutoModelForCausalLM
model = AutoModelForCausalLM.from_pretrained("zy22b/GeoMotionGPT", trust_remote_code=True)
motion_tokenizer = model.motion_tokenizer
# Tokenize motion features of shape (batch, time, 263)
motion_tokens = motion_tokenizer.encode(motion_features)
# Generate text token IDs from motion
generated_ids = model.generate_text(motion_tokens)
```
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Tuple
from transformers import PreTrainedModel, GPT2LMHeadModel, GPT2Config
# Handle both package and standalone imports
try:
from .configuration_geomotiongpt import GeoMotionGPTConfig
except ImportError:
from configuration_geomotiongpt import GeoMotionGPTConfig
# =====================================================
# Motion Tokenizer Components (DVQ-GSST)
# =====================================================
class Swish(nn.Module):
"""Swish activation function."""
def forward(self, x):
return x * torch.sigmoid(x)
class ResConv1DBlock(nn.Module):
"""Single residual convolution block."""
def __init__(self, n_in, n_state, dilation=1, activation='relu', norm=None):
super().__init__()
        padding = dilation  # keeps the temporal length unchanged for kernel_size=3
self.norm = norm
if norm == "LN":
self.norm1 = nn.LayerNorm(n_in)
self.norm2 = nn.LayerNorm(n_in)
elif norm == "GN":
self.norm1 = nn.GroupNorm(num_groups=32, num_channels=n_in, eps=1e-6, affine=True)
self.norm2 = nn.GroupNorm(num_groups=32, num_channels=n_in, eps=1e-6, affine=True)
elif norm == "BN":
self.norm1 = nn.BatchNorm1d(num_features=n_in, eps=1e-6, affine=True)
self.norm2 = nn.BatchNorm1d(num_features=n_in, eps=1e-6, affine=True)
else:
self.norm1 = nn.Identity()
self.norm2 = nn.Identity()
        if activation == "relu":
            self.activation1 = nn.ReLU()
            self.activation2 = nn.ReLU()
        elif activation == "silu":
            self.activation1 = Swish()
            self.activation2 = Swish()
        elif activation == "gelu":
            self.activation1 = nn.GELU()
            self.activation2 = nn.GELU()
        else:
            raise ValueError(f"Unsupported activation: {activation!r}")
self.conv1 = nn.Conv1d(n_in, n_state, 3, 1, padding, dilation)
self.conv2 = nn.Conv1d(n_state, n_in, 1, 1, 0)
def forward(self, x):
x_orig = x
if self.norm == "LN":
x = self.norm1(x.transpose(-2, -1))
x = self.activation1(x.transpose(-2, -1))
else:
x = self.norm1(x)
x = self.activation1(x)
x = self.conv1(x)
if self.norm == "LN":
x = self.norm2(x.transpose(-2, -1))
x = self.activation2(x.transpose(-2, -1))
else:
x = self.norm2(x)
x = self.activation2(x)
x = self.conv2(x)
return x + x_orig
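
# Illustrative sketch (not part of the model API): with kernel_size=3 and
# padding equal to the dilation, each ResConv1DBlock preserves the temporal
# length, so the residual addition is always shape-compatible.
def _example_resconv_shapes():
    block = ResConv1DBlock(n_in=64, n_state=64, dilation=9, activation='relu')
    x = torch.randn(2, 64, 50)   # (batch, channels, time)
    y = block(x)                 # dilation widens the receptive field only
    assert y.shape == x.shape    # (2, 64, 50)
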
class Resnet1D(nn.Module):
"""1D ResNet block composed of multiple ResConv1DBlocks."""
def __init__(self, n_in, n_depth, dilation_growth_rate=1,
reverse_dilation=True, activation='relu', norm=None):
super().__init__()
blocks = [
ResConv1DBlock(n_in, n_in, dilation=dilation_growth_rate ** depth,
activation=activation, norm=norm)
for depth in range(n_depth)
]
if reverse_dilation:
blocks = blocks[::-1]
self.model = nn.Sequential(*blocks)
def forward(self, x):
return self.model(x)
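
# Illustrative sketch (not part of the model API): Resnet1D stacks blocks with
# geometrically growing dilations. With n_depth=3 and dilation_growth_rate=3
# the dilations are 1, 3, 9, reversed to 9, 3, 1 when reverse_dilation=True.
def _example_resnet1d_dilations():
    net = Resnet1D(n_in=64, n_depth=3, dilation_growth_rate=3)
    dilations = [block.conv1.dilation[0] for block in net.model]
    assert dilations == [9, 3, 1]  # reverse_dilation defaults to True
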
class MotionEncoder(nn.Module):
"""Encoder for motion features with temporal downsampling."""
def __init__(self, input_dim=263, hidden_dim=512, nb_code=512,
down_t=3, stride_t=2, depth=3, dilation_growth_rate=3,
activation='relu', norm=None):
super().__init__()
blocks = []
filter_t, pad_t = stride_t * 2, stride_t // 2
blocks.append(nn.Conv1d(input_dim, hidden_dim, 3, 1, 1))
blocks.append(nn.ReLU())
for _ in range(down_t):
block = nn.Sequential(
nn.Conv1d(hidden_dim, hidden_dim, filter_t, stride_t, pad_t),
Resnet1D(hidden_dim, depth, dilation_growth_rate,
reverse_dilation=False, activation=activation, norm=norm),
)
blocks.append(block)
blocks.append(nn.Conv1d(hidden_dim, nb_code, 3, 1, 1))
self.model = nn.Sequential(*blocks)
def forward(self, x):
return self.model(x)
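
# Illustrative sketch (not part of the model API): each of the down_t stages
# shrinks the temporal axis by stride_t, so the overall downsampling ratio is
# stride_t ** down_t. With the defaults (down_t=3, stride_t=2), 96 input
# frames map to 12 latent steps over nb_code channels.
def _example_encoder_downsampling():
    enc = MotionEncoder(input_dim=263, hidden_dim=512, nb_code=512)
    x = torch.randn(2, 263, 96)     # (batch, features, time)
    z = enc(x)
    assert z.shape == (2, 512, 12)  # 96 / 2**3 = 12
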
class MotionDecoder(nn.Module):
"""Decoder for reconstructing motion from quantized features."""
def __init__(self, output_dim=263, hidden_dim=512, code_dim=512,
down_t=3, stride_t=2, depth=3, dilation_growth_rate=3,
activation='relu', norm=None):
super().__init__()
blocks = []
blocks.append(nn.Conv1d(code_dim, hidden_dim, 3, 1, 1))
blocks.append(nn.ReLU())
for _ in range(down_t):
block = nn.Sequential(
Resnet1D(hidden_dim, depth, dilation_growth_rate,
reverse_dilation=True, activation=activation, norm=norm),
                nn.Upsample(scale_factor=stride_t, mode='nearest'),  # mirror the encoder's stride_t downsampling
nn.Conv1d(hidden_dim, hidden_dim, 3, 1, 1)
)
blocks.append(block)
blocks.append(nn.Conv1d(hidden_dim, hidden_dim, 3, 1, 1))
blocks.append(nn.ReLU())
blocks.append(nn.Conv1d(hidden_dim, output_dim, 3, 1, 1))
self.model = nn.Sequential(*blocks)
def forward(self, x):
return self.model(x)
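
# Illustrative sketch (not part of the model API): the decoder mirrors the
# encoder, upsampling by stride_t per stage, so latent steps map back to the
# original frame count.
def _example_decoder_upsampling():
    dec = MotionDecoder(output_dim=263, hidden_dim=512, code_dim=512)
    z = torch.randn(2, 512, 12)         # (batch, code_dim, latent steps)
    x_rec = dec(z)
    assert x_rec.shape == (2, 263, 96)  # 12 * 2**3 = 96
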
class GumbelSoftmaxQuantizer(nn.Module):
"""Gumbel-Softmax Straight-Through quantizer for VQ-VAE."""
def __init__(self, nb_code=512, code_dim=512):
super().__init__()
self.nb_code = nb_code
self.code_dim = code_dim
self.codebook = nn.Embedding(nb_code, code_dim)
nn.init.uniform_(self.codebook.weight, -1.0 / nb_code, 1.0 / nb_code)
        self.tau = 0.4  # fixed Gumbel-Softmax temperature
def quantize(self, x):
"""Quantize encoder output to discrete indices."""
return x.argmax(dim=-1)
def dequantize(self, indices):
"""Convert indices back to embeddings."""
return self.codebook(indices)
def forward(self, x_encoder):
"""Forward pass with Gumbel-Softmax sampling."""
N, C, T = x_encoder.shape
x = x_encoder.permute(0, 2, 1).contiguous().view(-1, C)
# Gumbel-Softmax with straight-through
y_hard_st = F.gumbel_softmax(x, tau=self.tau, hard=True, dim=-1)
x_quantized = torch.matmul(y_hard_st, self.codebook.weight)
return x_quantized.view(N, T, -1).permute(0, 2, 1).contiguous()
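
# Illustrative sketch (not part of the model API): F.gumbel_softmax(hard=True)
# returns one-hot vectors in the forward pass but routes gradients through the
# soft sample (straight-through estimator), so the codebook lookup stays
# differentiable during training. `quantize` replaces the stochastic sample
# with a deterministic argmax at inference time.
def _example_straight_through():
    quantizer = GumbelSoftmaxQuantizer(nb_code=512, code_dim=512)
    logits = torch.randn(2, 512, 16, requires_grad=True)  # (N, C, T)
    x_quantized = quantizer(logits)  # hard selection, soft gradients
    x_quantized.sum().backward()
    assert logits.grad is not None   # gradients flow despite the one-hot
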
class MotionTokenizer(nn.Module):
"""
DVQ-GSST Motion Tokenizer.
Converts continuous motion features (263-dim HumanML3D format) to discrete tokens.
Args:
config: GeoMotionGPTConfig containing motion tokenizer parameters
Example:
```python
motion = torch.randn(1, 100, 263) # (batch, time, features)
tokens = motion_tokenizer.encode(motion) # (batch, time//8)
```
"""
def __init__(self, config: GeoMotionGPTConfig):
super().__init__()
self.config = config
self.encoder = MotionEncoder(
input_dim=config.motion_input_dim,
hidden_dim=config.motion_hidden_dim,
nb_code=config.motion_vocab_size,
down_t=config.motion_down_t,
depth=config.motion_depth,
dilation_growth_rate=config.motion_dilation_growth_rate,
)
self.decoder = MotionDecoder(
output_dim=config.motion_input_dim,
hidden_dim=config.motion_hidden_dim,
code_dim=config.motion_vocab_size,
down_t=config.motion_down_t,
depth=config.motion_depth,
dilation_growth_rate=config.motion_dilation_growth_rate,
)
self.quantizer = GumbelSoftmaxQuantizer(
nb_code=config.motion_vocab_size,
code_dim=config.motion_vocab_size,
)
def encode(self, motion: torch.Tensor) -> torch.Tensor:
"""
Encode motion features to discrete tokens.
Args:
motion: Motion features of shape (batch, time, 263)
Returns:
Token indices of shape (batch, time // downsample_ratio)
"""
# (batch, time, 263) -> (batch, 263, time)
x = motion.permute(0, 2, 1).float()
# Encode
x_enc = self.encoder(x) # (batch, nb_code, time')
# (batch, nb_code, time') -> (batch, time', nb_code)
x_enc = x_enc.permute(0, 2, 1).contiguous()
N, T, C = x_enc.shape
# Get token indices
indices = self.quantizer.quantize(x_enc.view(-1, C))
return indices.view(N, T)
def decode(self, tokens: torch.Tensor) -> torch.Tensor:
"""
Decode tokens back to motion features.
Args:
tokens: Token indices of shape (batch, time')
Returns:
Motion features of shape (batch, time, 263)
"""
# Get embeddings from tokens
x = self.quantizer.dequantize(tokens) # (batch, time', code_dim)
# (batch, time', code_dim) -> (batch, code_dim, time')
x = x.permute(0, 2, 1).contiguous()
# Decode
x_out = self.decoder(x) # (batch, 263, time)
# (batch, 263, time) -> (batch, time, 263)
return x_out.permute(0, 2, 1)
def forward(self, motion: torch.Tensor):
"""Forward pass for training (encode -> quantize -> decode)."""
x = motion.permute(0, 2, 1).float()
x_enc = self.encoder(x)
x_quant = self.quantizer(x_enc)
x_dec = self.decoder(x_quant)
return x_dec.permute(0, 2, 1)
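
# Illustrative sketch (not part of the model API): a full tokenizer round
# trip. `encode` is deterministic (argmax over encoder logits) while `forward`
# uses Gumbel-Softmax sampling for training. Assumes the default motion
# hyperparameters (down_t=3, stride_t=2, i.e. a downsampling ratio of 8).
def _example_tokenizer_roundtrip(config: GeoMotionGPTConfig):
    tokenizer = MotionTokenizer(config)
    motion = torch.randn(1, 96, config.motion_input_dim)  # (batch, time, features)
    tokens = tokenizer.encode(motion)                     # (1, 96 // 8) = (1, 12)
    motion_rec = tokenizer.decode(tokens)                 # (1, 96, motion_input_dim)
    return tokens, motion_rec
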
# =====================================================
# Main GeoMotionGPT Model
# =====================================================
class GeoMotionGPTPreTrainedModel(PreTrainedModel):
"""Base class for GeoMotionGPT models."""
config_class = GeoMotionGPTConfig
base_model_prefix = "geomotiongpt"
supports_gradient_checkpointing = True
def _init_weights(self, module):
"""Initialize weights."""
if isinstance(module, (nn.Linear, nn.Conv1d)):
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
class GeoMotionGPTForCausalLM(GeoMotionGPTPreTrainedModel):
"""
GeoMotionGPT Model for motion-to-text generation.
This model combines:
1. A VQ-VAE motion tokenizer (DVQ-GSST) for converting motion to discrete tokens
2. A fine-tuned GPT-2 model for generating text from motion tokens
Example:
```python
from transformers import AutoModelForCausalLM
import torch
# Load model
model = AutoModelForCausalLM.from_pretrained(
"zy22b/GeoMotionGPT",
trust_remote_code=True
)
# Access motion tokenizer
motion_tokenizer = model.motion_tokenizer
# Tokenize motion (batch, time, 263) -> (batch, tokens)
motion = torch.randn(1, 100, 263)
motion_tokens = motion_tokenizer.encode(motion)
    # Generate text token IDs from the motion tokens
    generated_ids = model.generate_text(motion_tokens)
```
"""
    _tied_weights_keys = ["language_model.lm_head.weight"]
def __init__(self, config: GeoMotionGPTConfig):
super().__init__(config)
# Motion tokenizer
self.motion_tokenizer = MotionTokenizer(config)
# Build GPT-2 config
gpt2_config = GPT2Config(
vocab_size=config.vocab_size,
n_positions=config.n_positions,
n_embd=config.n_embd,
n_layer=config.n_layer,
n_head=config.n_head,
n_inner=config.n_inner,
activation_function=config.activation_function,
resid_pdrop=config.resid_pdrop,
embd_pdrop=config.embd_pdrop,
attn_pdrop=config.attn_pdrop,
layer_norm_epsilon=config.layer_norm_epsilon,
initializer_range=config.initializer_range,
bos_token_id=config.bos_token_id,
eos_token_id=config.eos_token_id,
)
# Language model (GPT-2)
self.language_model = GPT2LMHeadModel(gpt2_config)
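        # Worked example for the width computation below (assuming GPT-2 small
        # defaults n_embd=768, n_head=12 and a hypothetical mot_factor=0.5):
        # per-head dim 768 // 12 = 64, scaled to int(64 * 0.5) = 32, times 12
        # heads gives mot_embed_dim = 384, which stays divisible by n_head.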
# Motion token embeddings (separate from text embeddings)
mot_embed_dim = int(config.n_embd // config.n_head * config.mot_factor) * config.n_head
self.motion_embed = nn.Embedding(
config.motion_vocab_size + 3, # +3 for special tokens (BOT, EOT, PAD)
mot_embed_dim
)
self.motion_head = nn.Linear(mot_embed_dim, config.motion_vocab_size + 3, bias=False)
# Projection layers for multi-modal fusion
self.motion_to_text_proj = nn.Linear(mot_embed_dim, config.n_embd)
self.text_to_motion_proj = nn.Linear(config.n_embd, mot_embed_dim)
# Initialize weights
self.post_init()
def get_input_embeddings(self):
return self.language_model.transformer.wte
def set_input_embeddings(self, value):
self.language_model.transformer.wte = value
def get_output_embeddings(self):
return self.language_model.lm_head
def set_output_embeddings(self, new_embeddings):
self.language_model.lm_head = new_embeddings
def encode_motion(self, motion: torch.Tensor) -> torch.Tensor:
"""
Encode motion features to discrete tokens.
Args:
motion: Motion features of shape (batch, time, 263)
Returns:
Token indices of shape (batch, time // 8)
"""
return self.motion_tokenizer.encode(motion)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
**kwargs
):
"""
Forward pass through the language model.
For motion-to-text generation, use the `generate_text` method instead.
"""
return self.language_model(
input_ids=input_ids,
attention_mask=attention_mask,
past_key_values=past_key_values,
labels=labels,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
"""Prepare inputs for text generation."""
return self.language_model.prepare_inputs_for_generation(
input_ids, past_key_values=past_key_values, **kwargs
)
@torch.no_grad()
def generate_text(
self,
motion_tokens: torch.Tensor,
max_new_tokens: int = 128,
num_beams: int = 4,
temperature: float = 0.7,
top_p: float = 0.9,
do_sample: bool = True,
**kwargs
    ) -> torch.Tensor:
"""
Generate text descriptions from motion tokens.
Args:
motion_tokens: Motion token indices of shape (batch, seq_len)
max_new_tokens: Maximum number of new tokens to generate
num_beams: Number of beams for beam search
temperature: Sampling temperature
top_p: Top-p sampling parameter
do_sample: Whether to use sampling
        Returns:
            Generated token IDs of shape (batch, new_tokens); decode them with
            a text tokenizer to obtain strings (see note below).
"""
device = motion_tokens.device
batch_size = motion_tokens.shape[0]
# Offset motion tokens (they come after text tokens)
motion_offset = self.config.text_vocab_size
input_ids = motion_tokens + motion_offset
# Add BOS token at the start
bos_tokens = torch.full(
(batch_size, 1),
self.config.bos_token_id,
dtype=torch.long,
device=device
)
input_ids = torch.cat([bos_tokens, input_ids], dim=1)
# Generate
outputs = self.language_model.generate(
input_ids=input_ids,
max_new_tokens=max_new_tokens,
num_beams=num_beams,
temperature=temperature,
top_p=top_p,
do_sample=do_sample,
pad_token_id=self.config.pad_token_id,
eos_token_id=self.config.eos_token_id,
**kwargs
)
# Decode only the generated part
generated_ids = outputs[:, input_ids.shape[1]:]
# Note: Actual text decoding requires a tokenizer
# Return raw generated IDs for now
return generated_ids
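
# Illustrative sketch (not part of the model API): `generate_text` returns raw
# token IDs; turning them into strings requires a text tokenizer. A minimal
# decoding pass, assuming the model was trained on top of the standard GPT-2
# vocabulary:
def _example_decode_generated_ids(model, motion_tokens):
    from transformers import GPT2Tokenizer
    tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
    generated_ids = model.generate_text(motion_tokens)
    return tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
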
# Register for AutoClass
GeoMotionGPTConfig.register_for_auto_class()
GeoMotionGPTForCausalLM.register_for_auto_class("AutoModelForCausalLM")