PULSE-code / experiments /nets /published_models.py
velvet-pine-22's picture
Upload folder using huggingface_hub
b4b2877 verified
"""
Published baseline models for NeurIPS 2026 benchmark experiments.
Contains faithful implementations of 6 published models:
1. DeepConvLSTM (Ordonez & Roggen, Sensors 2016) - Exp1/Exp3
2. InceptionTime (Fawaz et al., DMKD 2020) - Exp1/Exp3
3. MS-TCN++ (Li et al., TPAMI 2020) - Exp2
4. DiffAct (Liu et al., ICCV 2023) - Exp2
5. UnderPressure (Mourot et al., SCA/CGF 2022) - Exp3/Exp4a
6. emg2pose (Meta, NeurIPS 2024 D&B) - Exp4b
"""
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
# ============================================================
# 1. DeepConvLSTM (Ordonez & Roggen, Sensors 2016)
# "Deep Convolutional and LSTM Recurrent Neural Networks
# for Multimodal Wearable Activity Recognition"
# 4 Conv layers -> 2 LSTM layers -> pooling/per-frame output
# ============================================================
class DeepConvLSTMBackbone(nn.Module):
    """DeepConvLSTM backbone for sequence-level classification (Exp1).

    A stack of Conv1d -> BatchNorm -> ReLU -> Dropout layers followed by a
    unidirectional multi-layer LSTM. The sequence is summarised by the
    top-layer LSTM state at the last valid frame.

    Input: (B, T, C), optional mask (B, T) where non-zero marks a valid frame
           (same convention as the masked pooling in InceptionTimeBackbone)
    Output: (B, output_dim)
    """

    def __init__(self, input_dim, hidden_dim=128, num_conv_layers=4,
                 conv_filters=64, conv_kernel=5, num_lstm_layers=2):
        super().__init__()
        conv_layers = []
        in_ch = input_dim
        for i in range(num_conv_layers):
            is_last = i == num_conv_layers - 1
            conv_layers.append(nn.Sequential(
                nn.Conv1d(in_ch, conv_filters, conv_kernel, padding=conv_kernel // 2),
                nn.BatchNorm1d(conv_filters),
                nn.ReLU(),
                # Slightly stronger dropout on the last conv layer only.
                nn.Dropout(0.2 if is_last else 0.1),
            ))
            in_ch = conv_filters
        self.convs = nn.ModuleList(conv_layers)
        self.lstm = nn.LSTM(
            conv_filters, hidden_dim, num_layers=num_lstm_layers,
            batch_first=True, bidirectional=False,
            # nn.LSTM only applies dropout between stacked layers.
            dropout=0.2 if num_lstm_layers > 1 else 0,
        )
        self.output_dim = hidden_dim

    def forward(self, x, mask=None):
        """Encode a batch of sequences into one feature vector each.

        Args:
            x: input of shape (B, T, C).
            mask: optional (B, T) tensor; non-zero entries mark valid frames.
                When omitted, every frame is treated as valid.

        Returns:
            Tensor of shape (B, hidden_dim).
        """
        # x: (B, T, C) -> Conv expects (B, C, T)
        x = x.permute(0, 2, 1)
        for conv in self.convs:
            x = conv(x)
        x = x.permute(0, 2, 1)  # (B, T, conv_filters)
        out, (h_n, _) = self.lstm(x)
        if mask is None:
            # No padding information: the final hidden state of the top layer.
            return h_n[-1]  # (B, hidden_dim)
        # With padding, h_n would summarise the padded tail. Because the LSTM
        # is unidirectional, the top-layer output at the last valid frame is
        # the state the sequence would have had without the padding.
        valid = mask.to(out.device) != 0  # (B, T)
        steps = torch.arange(out.size(1), device=out.device)
        # Index of the last valid frame; an all-zero mask falls back to frame 0.
        last_idx = (valid.long() * steps).max(dim=1).values
        batch_idx = torch.arange(out.size(0), device=out.device)
        return out[batch_idx, last_idx]  # (B, hidden_dim)
class DeepConvLSTMContact(nn.Module):
    """DeepConvLSTM variant that emits one prediction per frame (Exp3).

    Conv1d/BatchNorm/ReLU/Dropout stack -> 2-layer bidirectional LSTM ->
    linear head applied at every time step.

    Input: (B, T, C)
    Output: (B, T, 2)
    """

    def __init__(self, input_dim, hidden_dim=64, num_conv_layers=4,
                 conv_filters=64, conv_kernel=5):
        super().__init__()
        # Channel width entering each conv layer: raw input first, then filters.
        widths_in = [input_dim if k == 0 else conv_filters
                     for k in range(num_conv_layers)]
        self.convs = nn.ModuleList([
            nn.Sequential(
                nn.Conv1d(width, conv_filters, conv_kernel,
                          padding=conv_kernel // 2),
                nn.BatchNorm1d(conv_filters),
                nn.ReLU(),
                nn.Dropout(0.1),
            )
            for width in widths_in
        ])
        self.lstm = nn.LSTM(
            conv_filters, hidden_dim,
            num_layers=2, batch_first=True,
            bidirectional=True, dropout=0.2,
        )
        # Forward and backward LSTM states are concatenated, hence 2 * hidden.
        self.head = nn.Linear(2 * hidden_dim, 2)

    def forward(self, x):
        """Return per-frame outputs of shape (B, T, 2) for input (B, T, C)."""
        feats = x.permute(0, 2, 1)  # channels-first for Conv1d
        for stage in self.convs:
            feats = stage(feats)
        feats = feats.permute(0, 2, 1)  # back to (B, T, conv_filters)
        seq_out, _ = self.lstm(feats)
        return self.head(seq_out)
# ============================================================
# 2. InceptionTime (Fawaz et al., DMKD 2020)
# "InceptionTime: Finding AlexNet for Time Series Classification"
# Inception modules with multi-scale convolutions + residual
# ============================================================
class InceptionModule(nn.Module):
    """One Inception module operating on (B, C, T) time series.

    A 1x1 bottleneck feeds several parallel convolutions of different kernel
    widths; a max-pool + 1x1 conv branch runs on the raw input. All branch
    outputs are concatenated on the channel axis, batch-normalised and passed
    through ReLU, giving n_filters * (len(kernel_sizes) + 1) channels.
    """

    def __init__(self, in_channels, n_filters=32, kernel_sizes=(9, 19, 39),
                 bottleneck_channels=32):
        super().__init__()
        # Channel reduction shared by all convolutional branches.
        self.bottleneck = nn.Conv1d(in_channels, bottleneck_channels,
                                    kernel_size=1, bias=False)
        # (width - 1) // 2 padding keeps the length unchanged for odd widths.
        self.convs = nn.ModuleList([
            nn.Conv1d(bottleneck_channels, n_filters, width,
                      padding=(width - 1) // 2, bias=False)
            for width in kernel_sizes
        ])
        # Pooling branch works on the un-bottlenecked input.
        self.maxpool_conv = nn.Sequential(
            nn.MaxPool1d(kernel_size=3, stride=1, padding=1),
            nn.Conv1d(in_channels, n_filters, kernel_size=1, bias=False),
        )
        n_branches = len(kernel_sizes) + 1
        self.bn = nn.BatchNorm1d(n_filters * n_branches)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Map (B, C, T) to (B, n_filters * n_branches, T)."""
        reduced = self.bottleneck(x)
        branches = [branch(reduced) for branch in self.convs]
        branches.append(self.maxpool_conv(x))
        stacked = torch.cat(branches, dim=1)
        return self.relu(self.bn(stacked))
class InceptionBlock(nn.Module):
    """`depth` Inception modules in sequence, wrapped by a shortcut connection.

    The shortcut is a 1x1 conv + BatchNorm projection when the input width
    differs from the block width (4 * n_filters); otherwise the input is added
    back unchanged.
    """

    def __init__(self, in_channels, n_filters=32, depth=3):
        super().__init__()
        # Each module emits 3 conv branches + 1 max-pool branch of n_filters.
        n_out = 4 * n_filters
        self.modules_list = nn.ModuleList([
            InceptionModule(in_channels if k == 0 else n_out, n_filters)
            for k in range(depth)
        ])
        # True only when the shortcut needs a projection to match widths.
        self.use_residual = in_channels != n_out
        if self.use_residual:
            self.residual = nn.Sequential(
                nn.Conv1d(in_channels, n_out, kernel_size=1, bias=False),
                nn.BatchNorm1d(n_out),
            )
        self.relu = nn.ReLU()

    def forward(self, x):
        """Map (B, in_channels, T) to (B, 4 * n_filters, T)."""
        out = x
        for module in self.modules_list:
            out = module(out)
        shortcut = self.residual(x) if self.use_residual else x
        return self.relu(out + shortcut)
class InceptionTimeBackbone(nn.Module):
    """InceptionTime feature extractor for sequence-level classification (Exp1).

    Stacks `num_blocks` Inception blocks and averages over time, optionally
    restricted to the frames selected by `mask`.

    Input: (B, T, C), optional mask (B, T)
    Output: (B, output_dim) with output_dim = 4 * n_filters
    """

    def __init__(self, input_dim, hidden_dim=128, n_filters=32, num_blocks=2, depth=3):
        super().__init__()
        # `hidden_dim` is accepted but not used by this class.
        width = n_filters * 4
        self.blocks = nn.ModuleList([
            InceptionBlock(input_dim if b == 0 else width, n_filters, depth)
            for b in range(num_blocks)
        ])
        self.output_dim = width

    def forward(self, x, mask=None):
        """Return time-pooled features of shape (B, 4 * n_filters)."""
        feats = x.permute(0, 2, 1)  # (B, T, C) -> (B, C, T)
        for block in self.blocks:
            feats = block(feats)
        if mask is None:
            return feats.mean(2)
        # Masked mean: zero out masked frames, divide by the valid-frame count
        # (clamped to 1 so an all-zero mask does not divide by zero).
        weights = mask.unsqueeze(1).float()
        n_valid = mask.sum(1, keepdim=True).float().clamp(min=1)
        return (feats * weights).sum(2) / n_valid
class InceptionTimeContact(nn.Module):
    """InceptionTime variant producing one prediction per frame (Exp3).

    Inception blocks followed by a 1x1 convolution head over time.

    Input: (B, T, C)
    Output: (B, T, 2)
    """

    def __init__(self, input_dim, hidden_dim=64, n_filters=32, num_blocks=2, depth=3):
        super().__init__()
        # `hidden_dim` is accepted but not used by this class.
        width = n_filters * 4
        self.blocks = nn.ModuleList([
            InceptionBlock(input_dim if b == 0 else width, n_filters, depth)
            for b in range(num_blocks)
        ])
        self.head = nn.Conv1d(width, 2, kernel_size=1)

    def forward(self, x):
        """Return per-frame outputs of shape (B, T, 2)."""
        feats = x.permute(0, 2, 1)  # channels-first for Conv1d
        for block in self.blocks:
            feats = block(feats)
        per_frame = self.head(feats)  # (B, 2, T)
        return per_frame.permute(0, 2, 1)
# ============================================================
# 3. MS-TCN++ (Li et al., TPAMI 2020)
# "MS-TCN++: Multi-Stage Temporal Convolutional Network
# for Action Segmentation"
# Key improvement: dual dilated layers in each residual block
# ============================================================
class DualDilatedResBlock(nn.Module):
    """Residual block with two parallel dilated convolutions.

    Both kernel-3 branches see the same input at different dilation rates.
    Their ReLU outputs are summed, fused by a 1x1 conv, batch-normalised,
    passed through ReLU and dropout, and added back to the input.
    Shape is preserved: (B, channels, T) -> (B, channels, T).
    """

    def __init__(self, channels, dilation1, dilation2):
        super().__init__()
        # padding == dilation keeps the length fixed for a kernel of size 3.
        self.conv1_dilated = nn.Conv1d(channels, channels, kernel_size=3,
                                       padding=dilation1, dilation=dilation1)
        self.conv2_dilated = nn.Conv1d(channels, channels, kernel_size=3,
                                       padding=dilation2, dilation=dilation2)
        self.conv_fusion = nn.Conv1d(channels, channels, kernel_size=1)
        self.bn = nn.BatchNorm1d(channels)
        self.dropout = nn.Dropout(0.3)

    def forward(self, x):
        """Apply the dual-dilation transform and the skip connection."""
        branch_a = F.relu(self.conv1_dilated(x))
        branch_b = F.relu(self.conv2_dilated(x))
        fused = self.conv_fusion(branch_a + branch_b)
        fused = F.relu(self.bn(fused))
        return self.dropout(fused) + x
class MSTCNPPStage(nn.Module):
    """One MS-TCN++ stage: 1x1 input conv, dual-dilated blocks, 1x1 classifier.

    Layer i pairs dilation 2**i with 2**(i + 1); the final layer uses 2**i
    for both branches.
    Maps (B, in_channels, T) to (B, num_classes, T).
    """

    def __init__(self, in_channels, hidden_channels, num_classes, num_layers=10):
        super().__init__()
        self.input_conv = nn.Conv1d(in_channels, hidden_channels, kernel_size=1)
        final = num_layers - 1
        blocks = []
        for level in range(num_layers):
            near = 2 ** level
            far = near if level == final else 2 ** (level + 1)
            blocks.append(DualDilatedResBlock(hidden_channels, near, far))
        self.layers = nn.ModuleList(blocks)
        self.output_conv = nn.Conv1d(hidden_channels, num_classes, kernel_size=1)

    def forward(self, x):
        """Return per-frame class scores of shape (B, num_classes, T)."""
        hidden = self.input_conv(x)
        for block in self.layers:
            hidden = block(hidden)
        return self.output_conv(hidden)
class MSTCNPP(nn.Module):
    """MS-TCN++ for temporal action segmentation (Exp2).

    The first stage consumes the input features; every later stage refines
    the softmax of the previous stage's class scores.

    Input: (B, T, C)
    Output: list with one (B, T, num_classes) tensor per stage
    """

    def __init__(self, input_dim, num_classes, hidden_dim=64, num_stages=4, num_layers=10):
        super().__init__()
        # Stage 0 reads features; refinement stages read class probabilities.
        stage_inputs = [input_dim] + [num_classes] * (num_stages - 1)
        self.stages = nn.ModuleList([
            MSTCNPPStage(width, hidden_dim, num_classes, num_layers)
            for width in stage_inputs
        ])

    def forward(self, x):
        """Run all stages and collect each stage's per-frame scores."""
        current = x.permute(0, 2, 1)  # (B, C, T)
        final_idx = len(self.stages) - 1
        outputs = []
        for idx, stage in enumerate(self.stages):
            current = stage(current)
            outputs.append(current.permute(0, 2, 1))  # (B, T, num_classes)
            # Every stage except the last hands probabilities to its successor.
            if idx != final_idx:
                current = F.softmax(current, dim=1)
        return outputs
# ============================================================
# 4. DiffAct (Liu et al., ICCV 2023)
# "Diffusion Action Segmentation"
# Denoising diffusion model for iterative action refinement.
# Simplified but faithful implementation.
# ============================================================
class ConditionalLayerNorm(nn.Module):
    """Single-group GroupNorm wrapper used inside the DiffAct blocks.

    Despite the name, `forward` takes no timestep argument: this is a plain
    `nn.GroupNorm` with one group and a learnable per-channel affine.
    """

    def __init__(self, channels):
        super().__init__()
        # One group: statistics are computed per sample over all channels
        # and positions together.
        self.norm = nn.GroupNorm(num_groups=1, num_channels=channels)

    def forward(self, x):
        """Normalise `x` of shape (B, channels, T); shape is unchanged."""
        normalised = self.norm(x)
        return normalised
class DiffActBlock(nn.Module):
    """Residual block of the DiffAct denoiser.

    norm -> dilated conv -> ReLU, add the projected time embedding,
    norm -> 1x1 conv -> ReLU -> dropout, then the skip connection.
    Maps (B, channels, T) to (B, channels, T).
    """

    def __init__(self, channels, dilation, time_emb_dim):
        super().__init__()
        # padding == dilation keeps the length fixed for a kernel of size 3.
        self.conv1 = nn.Conv1d(channels, channels, kernel_size=3,
                               padding=dilation, dilation=dilation)
        self.conv2 = nn.Conv1d(channels, channels, kernel_size=1)
        self.norm1 = ConditionalLayerNorm(channels)
        self.norm2 = ConditionalLayerNorm(channels)
        self.time_proj = nn.Linear(time_emb_dim, channels)
        self.dropout = nn.Dropout(0.1)

    def forward(self, x, time_emb):
        """Apply the block to `x` given a (B, time_emb_dim) timestep embedding."""
        hidden = F.relu(self.conv1(self.norm1(x)))
        # Broadcast the per-sample time vector over the temporal axis.
        time_bias = self.time_proj(time_emb).unsqueeze(-1)  # (B, C, 1)
        hidden = self.norm2(hidden + time_bias)
        hidden = F.relu(self.conv2(hidden))
        return self.dropout(hidden) + x
class DiffActConditionEncoder(nn.Module):
    """Dilated residual conv encoder that produces conditioning features.

    A 1x1 input conv is followed by `num_layers` residual layers whose
    dilation cycles through 1, 2, 4, 8, 16.
    Maps (B, input_dim, T) to (B, hidden_dim, T).
    """

    def __init__(self, input_dim, hidden_dim, num_layers=6):
        super().__init__()
        self.input_conv = nn.Conv1d(input_dim, hidden_dim, kernel_size=1)
        # Dilation wraps around every five layers.
        rates = [2 ** (k % 5) for k in range(num_layers)]
        self.layers = nn.ModuleList([
            nn.Sequential(
                nn.Conv1d(hidden_dim, hidden_dim, kernel_size=3,
                          padding=rate, dilation=rate),
                nn.BatchNorm1d(hidden_dim),
                nn.ReLU(),
                nn.Dropout(0.1),
            )
            for rate in rates
        ])

    def forward(self, x):
        """Encode channels-first input into (B, hidden_dim, T) features."""
        feats = self.input_conv(x)
        for layer in self.layers:
            feats = layer(feats) + feats  # residual connection
        return feats
class SinusoidalTimeEmbedding(nn.Module):
    """Sinusoidal embedding of the diffusion timestep followed by an MLP.

    The timestep is encoded with `dim // 2` sine and `dim // 2` cosine
    features at geometrically spaced frequencies, then passed through a
    two-layer MLP (dim -> 4 * dim -> dim).

    Args:
        dim: embedding width; any integer >= 1. For odd `dim` the sinusoidal
            features are zero-padded by one column to reach `dim`.

    Raises:
        ValueError: if `dim` is smaller than 1.
    """

    def __init__(self, dim):
        super().__init__()
        if dim < 1:
            raise ValueError(f"dim must be >= 1, got {dim}")
        self.dim = dim
        self.mlp = nn.Sequential(
            nn.Linear(dim, dim * 4),
            nn.GELU(),
            nn.Linear(dim * 4, dim),
        )

    def forward(self, t):
        """Embed timesteps.

        Args:
            t: tensor of timesteps with shape (B,).

        Returns:
            Tensor of shape (B, dim).
        """
        half_dim = self.dim // 2
        # max(..., 1) avoids a division by zero when half_dim == 1 (dim 2 or 3);
        # for half_dim >= 2 the divisor is unchanged.
        scale = math.log(10000) / max(half_dim - 1, 1)
        freqs = torch.exp(torch.arange(half_dim, device=t.device) * -scale)
        angles = t.unsqueeze(-1).float() * freqs.unsqueeze(0)
        emb = torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)
        if self.dim % 2 == 1:
            # sin/cos only provide 2 * half_dim == dim - 1 features for odd dim;
            # pad one zero column so the MLP input width matches.
            emb = F.pad(emb, (0, 1))
        return self.mlp(emb)
class DiffAct(nn.Module):
    """DiffAct: Diffusion Action Segmentation (Exp2).

    A condition encoder extracts temporal features and an initial
    (non-diffusion) prediction. A denoising network, conditioned on those
    features and a timestep embedding, maps a noisy class-probability sequence
    to per-frame class logits, i.e. it predicts the clean sample.

    During training: noises the softmax of the initial prediction at a random
    timestep and denoises it in one pass.
    During inference: starts from pure noise and iteratively refines it,
    treating the denoiser output as clean-sample logits exactly as in training.

    Input: (B, T, C)
    Output: list [initial_logits, denoised_logits], each (B, T, num_classes)

    Raises:
        ValueError: if `num_diffusion_steps` is smaller than 1.
    """

    def __init__(self, input_dim, num_classes, hidden_dim=64,
                 num_encoder_layers=6, num_denoise_layers=6,
                 num_diffusion_steps=10):
        super().__init__()
        if num_diffusion_steps < 1:
            raise ValueError(
                f"num_diffusion_steps must be >= 1, got {num_diffusion_steps}"
            )
        self.num_classes = num_classes
        self.num_steps = num_diffusion_steps
        # Condition encoder: extract temporal features from input
        self.condition_encoder = DiffActConditionEncoder(input_dim, hidden_dim, num_encoder_layers)
        # Initial prediction head (non-diffusion baseline)
        self.initial_head = nn.Conv1d(hidden_dim, num_classes, 1)
        # Time embedding
        self.time_emb = SinusoidalTimeEmbedding(hidden_dim)
        # Denoising network: noisy class sequence + condition features in,
        # class logits out.
        self.denoise_input = nn.Conv1d(num_classes + hidden_dim, hidden_dim, 1)
        self.denoise_blocks = nn.ModuleList()
        for i in range(num_denoise_layers):
            dilation = 2 ** (i % 5)
            self.denoise_blocks.append(DiffActBlock(hidden_dim, dilation, hidden_dim))
        self.denoise_output = nn.Conv1d(hidden_dim, num_classes, 1)
        # Noise schedule (cosine)
        self._setup_noise_schedule()

    def _setup_noise_schedule(self):
        """Register cosine-schedule buffers (betas and cumulative alpha terms)."""
        steps = self.num_steps
        s = 0.008
        t = torch.linspace(0, steps, steps + 1)
        alphas_cumprod = torch.cos(((t / steps) + s) / (1 + s) * math.pi * 0.5) ** 2
        alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
        betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
        # Clamping keeps every beta strictly inside (0, 1), so the square
        # roots below never hit exactly 0.
        betas = torch.clamp(betas, 0.0001, 0.999)
        alphas = 1.0 - betas
        alphas_cumprod = torch.cumprod(alphas, dim=0)
        self.register_buffer('betas', betas)
        self.register_buffer('alphas_cumprod', alphas_cumprod)
        self.register_buffer('sqrt_alphas_cumprod', torch.sqrt(alphas_cumprod))
        self.register_buffer('sqrt_one_minus_alphas_cumprod', torch.sqrt(1 - alphas_cumprod))

    def _add_noise(self, x_start, t, noise=None):
        """Add noise to x_start (B, K, T) at per-sample timesteps t (B,)."""
        if noise is None:
            noise = torch.randn_like(x_start)
        sqrt_alpha = self.sqrt_alphas_cumprod[t].view(-1, 1, 1)
        sqrt_one_minus = self.sqrt_one_minus_alphas_cumprod[t].view(-1, 1, 1)
        return sqrt_alpha * x_start + sqrt_one_minus * noise

    def _denoise_step(self, x_noisy, cond_features, time_emb):
        """Run the denoiser once; returns class logits of shape (B, K, T)."""
        x = torch.cat([x_noisy, cond_features], dim=1)  # (B, K+hidden, T)
        x = self.denoise_input(x)
        for block in self.denoise_blocks:
            x = block(x, time_emb)
        return self.denoise_output(x)

    def forward(self, x):
        """
        Training: returns [initial_pred, denoised_pred]
        Inference: returns [initial_pred, iteratively_denoised_pred]
        Both entries are logits of shape (B, T, num_classes).
        """
        x_in = x.permute(0, 2, 1)  # (B, C, T)
        B, _, T = x_in.shape
        # Encode condition features
        cond = self.condition_encoder(x_in)  # (B, hidden, T)
        initial_logits = self.initial_head(cond).permute(0, 2, 1)  # (B, T, num_classes)
        if self.training:
            # Noise the initial prediction and denoise it in a single pass.
            x_start = F.softmax(initial_logits, dim=-1).permute(0, 2, 1)  # (B, K, T)
            t = torch.randint(0, self.num_steps, (B,), device=x.device)
            noise = torch.randn_like(x_start)
            # detach(): the noisy input carries no gradient back to the
            # initial head; `cond` still does.
            x_noisy = self._add_noise(x_start.detach(), t, noise)
            time_emb = self.time_emb(t)
            denoised = self._denoise_step(x_noisy, cond, time_emb)
            return [initial_logits, denoised.permute(0, 2, 1)]

        # Inference: iterative refinement from pure noise. The denoiser output
        # is interpreted as clean-sample logits, consistent with the training
        # branch, which returns it as the class prediction.
        x_t = torch.randn(B, self.num_classes, T, device=x.device)
        x0_logits = None
        for step in reversed(range(self.num_steps)):
            t = torch.full((B,), step, device=x.device, dtype=torch.long)
            x0_logits = self._denoise_step(x_t, cond, self.time_emb(t))
            if step == 0:
                break
            # Same space as the training x_start: class probabilities.
            x0 = F.softmax(x0_logits, dim=1)
            # Noise implied by the current sample and the clean estimate.
            eps = (x_t - self.sqrt_alphas_cumprod[step] * x0) \
                / self.sqrt_one_minus_alphas_cumprod[step]
            # Deterministic (DDIM-style) move to the previous noise level.
            x_t = (self.sqrt_alphas_cumprod[step - 1] * x0
                   + self.sqrt_one_minus_alphas_cumprod[step - 1] * eps)
        return [initial_logits, x0_logits.permute(0, 2, 1)]
# ============================================================
# 5. UnderPressure (Mourot et al., SCA/CGF 2022)
# "UnderPressure: Deep Learning for Foot Contact Detection,
# Ground Reaction Force Estimation and Footskate Cleanup"
# GRU-based architecture for contact detection + force regression.
# Adapted for hand contact detection and MoCap->Pressure prediction.
# ============================================================
class UnderPressureContact(nn.Module):
    """UnderPressure-style network adapted for hand contact detection (Exp3).

    Two conv layers (kernels 7 and 5) extract local temporal features, a
    bidirectional GRU models the sequence, and an MLP head scores each frame.

    Input: (B, T, C)
    Output: (B, T, 2) [right_contact, left_contact]
    """

    def __init__(self, input_dim, hidden_dim=64, num_gru_layers=2):
        super().__init__()
        # Conv stem; kernel // 2 padding preserves the sequence length.
        stem = []
        width_in = input_dim
        for kernel in (7, 5):
            stem.extend([
                nn.Conv1d(width_in, hidden_dim, kernel, padding=kernel // 2),
                nn.BatchNorm1d(hidden_dim),
                nn.ReLU(),
            ])
            width_in = hidden_dim
        self.feature_extractor = nn.Sequential(*stem)
        # nn.GRU only applies dropout between stacked layers.
        gru_dropout = 0.2 if num_gru_layers > 1 else 0
        self.gru = nn.GRU(
            hidden_dim, hidden_dim,
            num_layers=num_gru_layers, batch_first=True,
            bidirectional=True, dropout=gru_dropout,
        )
        # Both GRU directions are concatenated, hence 2 * hidden_dim inputs.
        self.contact_head = nn.Sequential(
            nn.Linear(2 * hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(hidden_dim, 2),
        )

    def forward(self, x):
        """Return per-frame contact scores of shape (B, T, 2)."""
        local = self.feature_extractor(x.permute(0, 2, 1))  # (B, hidden, T)
        sequence, _ = self.gru(local.permute(0, 2, 1))  # (B, T, 2 * hidden)
        return self.contact_head(sequence)
class UnderPressureRegressor(nn.Module):
    """UnderPressure-style network adapted for MoCap -> Pressure regression (Exp4a).

    Three conv layers (kernels 7, 5, 3) feed a bidirectional GRU; an MLP head
    regresses `output_dim` values for every frame.

    Input: (B, T, input_dim)
    Output: (B, T, output_dim)
    """

    def __init__(self, input_dim, output_dim, hidden_dim=128, num_gru_layers=2):
        super().__init__()
        # Conv stem; kernel // 2 padding preserves the sequence length.
        stem = []
        width_in = input_dim
        for kernel in (7, 5, 3):
            stem.extend([
                nn.Conv1d(width_in, hidden_dim, kernel, padding=kernel // 2),
                nn.BatchNorm1d(hidden_dim),
                nn.ReLU(),
            ])
            width_in = hidden_dim
        self.feature_extractor = nn.Sequential(*stem)
        # nn.GRU only applies dropout between stacked layers.
        gru_dropout = 0.2 if num_gru_layers > 1 else 0
        self.gru = nn.GRU(
            hidden_dim, hidden_dim,
            num_layers=num_gru_layers, batch_first=True,
            bidirectional=True, dropout=gru_dropout,
        )
        # Both GRU directions are concatenated, hence 2 * hidden_dim inputs.
        self.regression_head = nn.Sequential(
            nn.Linear(2 * hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(hidden_dim, output_dim),
        )

    def forward(self, x):
        """Return per-frame regression outputs of shape (B, T, output_dim)."""
        local = self.feature_extractor(x.permute(0, 2, 1))  # (B, hidden, T)
        sequence, _ = self.gru(local.permute(0, 2, 1))  # (B, T, 2 * hidden)
        return self.regression_head(sequence)
# ============================================================
# 6. emg2pose (Meta/Facebook Research, NeurIPS 2024 D&B)
# "emg2pose: A Large and Diverse Benchmark for
# Surface Electromyographic Hand Pose Estimation"
# CNN feature extractor + Transformer encoder,
# with optional velocity-based integration (vemg2pose).
# ============================================================
class EMG2PoseEncoder(nn.Module):
    """Multi-scale CNN front end followed by a Transformer encoder.

    Three parallel conv branches (kernels 3, 7, 15) are concatenated to
    exactly `hidden_dim` channels, mixed by a 1x1 projection, and passed
    through a batch-first Transformer encoder.

    Input: (B, T, input_dim)
    Output: (B, T, hidden_dim)

    Args:
        input_dim: number of input channels.
        hidden_dim: model width; must be >= 4 and, as required by
            `nn.TransformerEncoderLayer`, divisible by `nhead`.
        num_transformer_layers: number of Transformer encoder layers.
        nhead: number of attention heads.

    Raises:
        ValueError: if `hidden_dim` is smaller than 4.
    """

    def __init__(self, input_dim, hidden_dim=128, num_transformer_layers=4, nhead=4):
        super().__init__()
        if hidden_dim < 4:
            raise ValueError(f"hidden_dim must be >= 4, got {hidden_dim}")
        # Split hidden_dim across the branches. The large-kernel branch takes
        # the remainder so the widths always sum to hidden_dim, including when
        # hidden_dim is not a multiple of 4 (for multiples of 4 this is the
        # usual 1/2, 1/4, 1/4 split).
        small_ch = hidden_dim // 2
        medium_ch = hidden_dim // 4
        large_ch = hidden_dim - small_ch - medium_ch
        # Multi-scale CNN feature extractor
        self.conv_small = nn.Sequential(
            nn.Conv1d(input_dim, small_ch, 3, padding=1),
            nn.BatchNorm1d(small_ch),
            nn.ReLU(),
        )
        self.conv_medium = nn.Sequential(
            nn.Conv1d(input_dim, medium_ch, 7, padding=3),
            nn.BatchNorm1d(medium_ch),
            nn.ReLU(),
        )
        self.conv_large = nn.Sequential(
            nn.Conv1d(input_dim, large_ch, 15, padding=7),
            nn.BatchNorm1d(large_ch),
            nn.ReLU(),
        )
        # Projection to hidden_dim
        self.proj = nn.Sequential(
            nn.Conv1d(hidden_dim, hidden_dim, 1),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(),
        )
        # Transformer encoder for temporal modeling
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=hidden_dim, nhead=nhead,
            dim_feedforward=hidden_dim * 4,
            dropout=0.1, batch_first=True,
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_transformer_layers)

    def forward(self, x):
        """Encode (B, T, input_dim) into (B, T, hidden_dim)."""
        # x: (B, T, C) -> (B, C, T)
        x_t = x.permute(0, 2, 1)
        f_small = self.conv_small(x_t)
        f_medium = self.conv_medium(x_t)
        f_large = self.conv_large(x_t)
        feat = torch.cat([f_small, f_medium, f_large], dim=1)
        feat = self.proj(feat).permute(0, 2, 1)  # (B, T, hidden)
        return self.transformer(feat)
class EMG2Pose(nn.Module):
    """emg2pose-style model for EMG -> hand pose regression (Exp4b).

    Encodes the input with `EMG2PoseEncoder` and decodes every frame with a
    small MLP. With `use_velocity=True` the MLP output is treated as a
    per-frame velocity that is cumulatively summed over time and offset by a
    learnable initial position; otherwise the MLP output is the position.

    Input: (B, T, input_dim) [EMG channels]
    Output: (B, T, output_dim) [hand joint positions]
    """

    def __init__(self, input_dim, output_dim, hidden_dim=128,
                 num_transformer_layers=4, use_velocity=True):
        super().__init__()
        self.use_velocity = use_velocity
        self.encoder = EMG2PoseEncoder(input_dim, hidden_dim, num_transformer_layers)
        # Both modes use the same head architecture; only the attribute name
        # (and the meaning of the output) differs.
        head = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_dim // 2, output_dim),
        )
        if use_velocity:
            self.velocity_head = head
            # Learnable starting pose added to the integrated velocity.
            self.initial_pos = nn.Parameter(torch.zeros(1, 1, output_dim))
        else:
            self.position_head = head

    def forward(self, x):
        """Return per-frame outputs of shape (B, T, output_dim)."""
        features = self.encoder(x)  # (B, T, hidden)
        if not self.use_velocity:
            return self.position_head(features)
        velocity = self.velocity_head(features)  # (B, T, output_dim)
        # Running sum over time turns velocities into positions.
        return torch.cumsum(velocity, dim=1) + self.initial_pos