""" MARS: Multi-scale Adaptive Recurrence with State compression ============================================================ An innovative method for super long sequence modeling in sequential recommendation. Key innovations: 1. Temporal-Aware Delta Network (TADN) for O(n) long-range modeling - Explicit exponential temporal decay in state updates - Input-dependent gating for selective memory retention 2. Compressive Memory Tokens - Fixed-size learnable memory that compresses arbitrarily long histories - Acts as information bottleneck (denoising effect per Rec2PM) 3. Dual-Branch Architecture with Learned Fusion - Long-term branch: TADN layers processing full history at O(n) cost - Short-term branch: Standard self-attention on recent K interactions - Adaptive gating fusion that balances long/short-term signals per user 4. Multi-Scale Temporal Encoding - Absolute time embeddings + relative time deltas + periodic components - Captures daily/weekly/seasonal patterns in user behavior This combines ideas from HyTRec (2602.18283), Rec2PM (2602.11605), SIGMA (2408.11451), and HSTU (2402.17152) into a unified architecture. """ import math import torch import torch.nn as nn import torch.nn.functional as F from typing import Optional, Tuple, Dict class TemporalEncoding(nn.Module): """Multi-scale temporal encoding with periodic components. Captures absolute time, relative time deltas, and periodic patterns (daily, weekly cycles) in user behavior. """ def __init__(self, embed_dim: int, max_periods: int = 4): super().__init__() self.embed_dim = embed_dim # Relative time delta projection self.time_delta_proj = nn.Linear(1, embed_dim) # Periodic components (daily=86400s, weekly=604800s, etc.) periods = [3600, 86400, 604800, 2592000][:max_periods] self.register_buffer('periods', torch.tensor(periods, dtype=torch.float32)) self.periodic_proj = nn.Linear(max_periods * 2, embed_dim) # sin + cos # Learnable position encoding for sequence order self.layernorm = nn.LayerNorm(embed_dim) def forward(self, timestamps: torch.Tensor) -> torch.Tensor: """ Args: timestamps: (batch, seq_len) absolute timestamps in seconds Returns: temporal_emb: (batch, seq_len, embed_dim) """ B, T = timestamps.shape # 1. Relative time deltas (seconds since previous interaction) time_deltas = torch.zeros_like(timestamps) time_deltas[:, 1:] = timestamps[:, 1:] - timestamps[:, :-1] time_deltas = time_deltas.clamp(min=0) # Log-scale for better numerical properties log_deltas = torch.log1p(time_deltas).unsqueeze(-1) # (B, T, 1) delta_emb = self.time_delta_proj(log_deltas) # (B, T, D) # 2. Periodic components ts_expanded = timestamps.unsqueeze(-1) # (B, T, 1) periods = self.periods.view(1, 1, -1) # (1, 1, P) angles = 2 * math.pi * ts_expanded / periods # (B, T, P) periodic_features = torch.cat([ torch.sin(angles), torch.cos(angles) ], dim=-1) # (B, T, 2*P) periodic_emb = self.periodic_proj(periodic_features) # (B, T, D) # 3. Combine temporal_emb = self.layernorm(delta_emb + periodic_emb) return temporal_emb class TADNLayer(nn.Module): """Temporal-Aware Delta Network Layer. Linear complexity O(n) recurrent layer with: - Delta rule state updates (inspired by HyTRec) - Explicit temporal decay gating - Input-dependent selective memory The state matrix S is updated as: S_t = S_{t-1} * (I - g_t * beta_t * k_t * k_t^T) + beta_t * v_t * k_t^T where g_t incorporates temporal decay: g_t = alpha * sigmoid(W_g * [h_t, delta_h_t]) * tau_t + (1-alpha) * g_static tau_t = exp(-(t_current - t_behavior) / T) """ def __init__(self, embed_dim: int, state_dim: int = 64, dropout: float = 0.1): super().__init__() self.embed_dim = embed_dim self.state_dim = state_dim # Query, Key, Value projections self.q_proj = nn.Linear(embed_dim, state_dim) self.k_proj = nn.Linear(embed_dim, state_dim) self.v_proj = nn.Linear(embed_dim, embed_dim) # Gating mechanism self.gate_proj = nn.Linear(embed_dim * 2, embed_dim) self.beta_proj = nn.Linear(embed_dim, state_dim) # Temporal decay parameters self.alpha = nn.Parameter(torch.tensor(0.5)) self.time_scale = nn.Parameter(torch.tensor(1.0)) # Static gate (learnable baseline) self.gate_static = nn.Parameter(torch.ones(embed_dim) * 0.5) # Output self.out_proj = nn.Linear(embed_dim, embed_dim) self.layernorm = nn.LayerNorm(embed_dim) self.dropout = nn.Dropout(dropout) def forward( self, x: torch.Tensor, timestamps: Optional[torch.Tensor] = None, mask: Optional[torch.Tensor] = None ) -> torch.Tensor: """ Args: x: (batch, seq_len, embed_dim) input sequence timestamps: (batch, seq_len) timestamps for temporal decay mask: (batch, seq_len) boolean mask (True = valid) Returns: output: (batch, seq_len, embed_dim) """ B, T, D = x.shape # Project to Q, K, V q = self.q_proj(x) # (B, T, state_dim) k = self.k_proj(x) # (B, T, state_dim) v = self.v_proj(x) # (B, T, D) # Beta (key importance scaling) beta = torch.sigmoid(self.beta_proj(x)) # (B, T, state_dim) # Temporal decay if timestamps is not None: # Compute recency-based decay with proper normalization # Use the LAST VALID position's timestamp as reference # Normalize by log(1 + delta) to handle large time ranges (seconds → years) t_last = timestamps[:, -1:].unsqueeze(-1) # (B, 1, 1) - last timestamp t_behavior = timestamps.unsqueeze(-1) # (B, T, 1) time_delta = (t_last - t_behavior).clamp(min=0) # Log-normalize: log(1 + delta_seconds / 3600) → hours-scale log_delta = torch.log1p(time_delta / 3600.0) # Normalize to hours # Learnable time scale controls the decay rate tau = torch.exp( -log_delta / (torch.abs(self.time_scale) * 10.0 + 1.0) ) # (B, T, 1), values in [0, 1] else: # Fallback: linear decay positions = torch.arange(T, device=x.device).float() tau = torch.exp(-positions / (T + 1e-6)).view(1, T, 1) # Dynamic gating with temporal awareness # Delta of hidden states for change detection x_shifted = torch.cat([torch.zeros_like(x[:, :1]), x[:, :-1]], dim=1) delta_x = x - x_shifted gate_input = torch.cat([x, delta_x], dim=-1) # (B, T, 2*D) alpha = torch.sigmoid(self.alpha) g_dynamic = torch.sigmoid(self.gate_proj(gate_input)) # (B, T, D) g = alpha * g_dynamic * tau + (1 - alpha) * torch.sigmoid(self.gate_static) # Recurrent state update with delta rule # Use chunked processing for better GPU utilization chunk_size = min(64, T) # Process in chunks for efficiency outputs = [] S = torch.zeros(B, self.state_dim, D, device=x.device) # State matrix for chunk_start in range(0, T, chunk_size): chunk_end = min(chunk_start + chunk_size, T) for t in range(chunk_start, chunk_end): k_t = k[:, t] # (B, state_dim) v_t = v[:, t] # (B, D) beta_t = beta[:, t] # (B, state_dim) g_t = g[:, t] # (B, D) q_t = q[:, t] # (B, state_dim) # Delta rule update: erase old, write new # Clamp erase to [0, 1] for stability erase = torch.einsum('bs,bd->bsd', beta_t * k_t, g_t).clamp(0, 1) write = torch.einsum('bs,bd->bsd', beta_t * k_t, v_t) if mask is not None: valid = mask[:, t].float().view(B, 1, 1) S = S * (1 - erase * valid) + write * valid else: S = S * (1 - erase) + write # Clamp state for numerical stability S = S.clamp(-10, 10) # Read from state out_t = torch.einsum('bs,bsd->bd', q_t, S) outputs.append(out_t) output = torch.stack(outputs, dim=1) # (B, T, D) output = self.out_proj(self.dropout(output)) output = self.layernorm(x + output) # Residual connection return output class CompressiveMemory(nn.Module): """Compressive Memory Module. Compresses long sequence history into a fixed number of memory tokens. Acts as information bottleneck (denoising per Rec2PM theory). Uses cross-attention: memory queries attend to sequence to extract summary. """ def __init__(self, embed_dim: int, num_memory_tokens: int = 8, num_heads: int = 2): super().__init__() self.num_memory_tokens = num_memory_tokens # Learnable memory query tokens self.memory_queries = nn.Parameter( torch.randn(num_memory_tokens, embed_dim) * 0.02 ) # Cross-attention: memory queries attend to sequence self.cross_attn = nn.MultiheadAttention( embed_dim=embed_dim, num_heads=num_heads, batch_first=True, dropout=0.1 ) self.ffn = nn.Sequential( nn.Linear(embed_dim, embed_dim * 4), nn.GELU(), nn.Dropout(0.1), nn.Linear(embed_dim * 4, embed_dim), nn.Dropout(0.1), ) self.norm1 = nn.LayerNorm(embed_dim) self.norm2 = nn.LayerNorm(embed_dim) def forward( self, sequence: torch.Tensor, mask: Optional[torch.Tensor] = None ) -> torch.Tensor: """ Args: sequence: (batch, seq_len, embed_dim) - encoded sequence mask: (batch, seq_len) boolean mask (True = valid, False = padding) Returns: memory: (batch, num_memory_tokens, embed_dim) """ B = sequence.shape[0] # Expand memory queries for batch queries = self.memory_queries.unsqueeze(0).expand(B, -1, -1) # (B, M, D) # Cross-attention with key padding mask # nn.MultiheadAttention expects key_padding_mask where True = IGNORE if mask is not None: key_padding_mask = ~mask # Invert: True means padding (to ignore) else: key_padding_mask = None attn_out, _ = self.cross_attn( query=queries, key=sequence, value=sequence, key_padding_mask=key_padding_mask ) memory = self.norm1(queries + attn_out) memory = self.norm2(memory + self.ffn(memory)) return memory class ShortTermAttention(nn.Module): """Standard self-attention block for short-term (recent) interactions. Uses standard causal multi-head attention — full expressiveness for the most recent K items where O(K²) is acceptable. """ def __init__(self, embed_dim: int, num_heads: int = 2, num_layers: int = 2, dropout: float = 0.1): super().__init__() encoder_layer = nn.TransformerEncoderLayer( d_model=embed_dim, nhead=num_heads, dim_feedforward=embed_dim * 4, dropout=dropout, activation='gelu', batch_first=True, norm_first=True ) self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers) def forward( self, x: torch.Tensor, mask: Optional[torch.Tensor] = None ) -> torch.Tensor: """ Args: x: (batch, K, embed_dim) recent interactions mask: (batch, K) boolean mask Returns: output: (batch, K, embed_dim) """ T = x.shape[1] # Causal mask causal_mask = torch.triu( torch.ones(T, T, device=x.device, dtype=torch.bool), diagonal=1 ) # Padding mask src_key_padding_mask = ~mask if mask is not None else None output = self.encoder( x, mask=causal_mask, src_key_padding_mask=src_key_padding_mask ) return output class AdaptiveFusionGate(nn.Module): """Adaptive fusion gate that balances long-term and short-term signals. Per-user, per-timestep gating: output = sigma(gate) * long_term + (1 - sigma(gate)) * short_term """ def __init__(self, embed_dim: int): super().__init__() self.gate = nn.Sequential( nn.Linear(embed_dim * 3, embed_dim), nn.GELU(), nn.Linear(embed_dim, embed_dim), nn.Sigmoid() ) def forward( self, long_term: torch.Tensor, short_term: torch.Tensor, memory: torch.Tensor ) -> torch.Tensor: """ Args: long_term: (batch, embed_dim) short_term: (batch, embed_dim) memory: (batch, embed_dim) compressed memory summary Returns: fused: (batch, embed_dim) """ gate_input = torch.cat([long_term, short_term, memory], dim=-1) g = self.gate(gate_input) return g * long_term + (1 - g) * short_term class MARS(nn.Module): """ MARS: Multi-scale Adaptive Recurrence with State compression Architecture: Input: Full user interaction sequence + timestamps | v [Item Embedding + Temporal Encoding] | +---- Long-term Branch (TADN layers, O(n)) | | | [Compressive Memory] → memory tokens | | +---- Short-term Branch (Self-Attention on recent K items) | v [Adaptive Fusion Gate] | v [Prediction Head] → next item scores Args: num_items: number of unique items embed_dim: embedding dimension max_seq_len: maximum sequence length (can be very long, e.g. 2048) short_term_len: number of recent items for short-term branch num_memory_tokens: number of compressive memory tokens num_tadn_layers: number of TADN layers in long-term branch num_attn_layers: number of attention layers in short-term branch num_heads: number of attention heads state_dim: state dimension for TADN dropout: dropout rate """ def __init__( self, num_items: int, embed_dim: int = 64, max_seq_len: int = 512, short_term_len: int = 50, num_memory_tokens: int = 8, num_tadn_layers: int = 3, num_attn_layers: int = 2, num_heads: int = 2, state_dim: int = 64, dropout: float = 0.1, ): super().__init__() self.num_items = num_items self.embed_dim = embed_dim self.max_seq_len = max_seq_len self.short_term_len = short_term_len self.num_memory_tokens = num_memory_tokens # Item embeddings (0 = padding) self.item_embedding = nn.Embedding(num_items + 1, embed_dim, padding_idx=0) # Temporal encoding self.temporal_encoding = TemporalEncoding(embed_dim) # Learnable position encoding (for short-term branch) self.position_embedding = nn.Embedding(max_seq_len, embed_dim) # Input processing self.input_norm = nn.LayerNorm(embed_dim) self.input_dropout = nn.Dropout(dropout) # Long-term branch: stack of TADN layers self.tadn_layers = nn.ModuleList([ TADNLayer(embed_dim, state_dim, dropout) for _ in range(num_tadn_layers) ]) # Compressive memory self.compressive_memory = CompressiveMemory( embed_dim, num_memory_tokens, num_heads ) # Short-term branch: standard self-attention self.short_term_attn = ShortTermAttention( embed_dim, num_heads, num_attn_layers, dropout ) # Adaptive fusion self.fusion_gate = AdaptiveFusionGate(embed_dim) # Output projection self.output_norm = nn.LayerNorm(embed_dim) self.output_proj = nn.Linear(embed_dim, embed_dim) # Initialize weights self._init_weights() def _init_weights(self): """Initialize with truncated normal distribution.""" for name, param in self.named_parameters(): if 'weight' in name and param.dim() >= 2: nn.init.trunc_normal_(param, std=0.02) elif 'bias' in name: nn.init.zeros_(param) # Special init for item embeddings nn.init.trunc_normal_(self.item_embedding.weight, std=0.02) nn.init.zeros_(self.item_embedding.weight[0]) # Padding = zero @property def item_embeddings(self): """Access item embedding table (for evaluation).""" return self.item_embedding def encode( self, item_ids: torch.Tensor, timestamps: Optional[torch.Tensor] = None, mask: Optional[torch.Tensor] = None ) -> torch.Tensor: """ Encode a full sequence into user representations. Args: item_ids: (batch, seq_len) item indices (0 = padding) timestamps: (batch, seq_len) timestamps in seconds mask: (batch, seq_len) boolean mask (True = valid) Returns: user_emb: (batch, embed_dim) final user representation """ B, T = item_ids.shape # Create mask from padding if not provided if mask is None: mask = (item_ids != 0) # 1. Item + Temporal Embeddings item_emb = self.item_embedding(item_ids) # (B, T, D) if timestamps is not None: temp_emb = self.temporal_encoding(timestamps.float()) item_emb = item_emb + temp_emb # Add position embeddings (only for the sequence order) positions = torch.arange(T, device=item_ids.device).unsqueeze(0) positions = positions.clamp(max=self.max_seq_len - 1) pos_emb = self.position_embedding(positions) item_emb = self.input_norm(item_emb + pos_emb) item_emb = self.input_dropout(item_emb) # 2. Long-term Branch: TADN over full sequence long_term_repr = item_emb for tadn in self.tadn_layers: long_term_repr = tadn(long_term_repr, timestamps, mask) # Compress long-term into memory tokens memory = self.compressive_memory(long_term_repr, mask) # (B, M, D) memory_summary = memory.mean(dim=1) # (B, D) - aggregated memory # Get last valid long-term representation # Use mask to find last valid position lengths = mask.sum(dim=1).long() # (B,) long_term_last = long_term_repr[ torch.arange(B, device=item_ids.device), (lengths - 1).clamp(min=0) ] # (B, D) # 3. Short-term Branch: Attention on recent K items # With right-padding, valid items are at positions 0...(length-1) # Extract last K valid items per user K = min(self.short_term_len, T) # For each user, get the last K valid positions short_item_ids_list = [] short_ts_list = [] short_mask_list = [] for b in range(B): seq_len = lengths[b].item() actual_k = min(K, seq_len) start = max(0, seq_len - K) end = seq_len # Extract valid items and pad to K ids = item_ids[b, start:end] pad_len = K - actual_k if pad_len > 0: ids = torch.cat([ids, torch.zeros(pad_len, dtype=ids.dtype, device=ids.device)]) short_item_ids_list.append(ids) if timestamps is not None: ts = timestamps[b, start:end] if pad_len > 0: ts = torch.cat([ts, torch.zeros(pad_len, dtype=ts.dtype, device=ts.device)]) short_ts_list.append(ts) m = torch.zeros(K, dtype=torch.bool, device=item_ids.device) m[:actual_k] = True short_mask_list.append(m) short_item_ids = torch.stack(short_item_ids_list) # (B, K) short_mask = torch.stack(short_mask_list) # (B, K) short_emb = self.item_embedding(short_item_ids) if timestamps is not None: short_ts = torch.stack(short_ts_list) # (B, K) short_temp = self.temporal_encoding(short_ts.float()) short_emb = short_emb + short_temp short_positions = torch.arange(K, device=item_ids.device).unsqueeze(0) short_positions = short_positions.clamp(max=self.max_seq_len - 1) short_emb = short_emb + self.position_embedding(short_positions) short_emb = self.input_norm(short_emb) short_term_repr = self.short_term_attn(short_emb, short_mask) # Get last valid short-term representation short_lengths = short_mask.sum(dim=1).long() short_term_last = short_term_repr[ torch.arange(B, device=item_ids.device), (short_lengths - 1).clamp(min=0) ] # (B, D) # 4. Adaptive Fusion user_emb = self.fusion_gate(long_term_last, short_term_last, memory_summary) user_emb = self.output_proj(self.output_norm(user_emb)) return user_emb def forward( self, batch: Dict[str, torch.Tensor] ) -> torch.Tensor: """ Training forward pass. Expected batch format (flat tensors, matching Yambda convention): - item_ids: (batch, max_seq_len) padded item sequences - timestamps: (batch, max_seq_len) padded timestamps - mask: (batch, max_seq_len) boolean mask - positive_ids: (batch,) positive next items - negative_ids: (batch, num_neg) negative items Returns: loss: scalar BCE loss """ if self.training: return self._training_forward(batch) else: return self._eval_forward(batch) def _training_forward(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor: """Compute training loss with next-item prediction.""" item_ids = batch['item_ids'] # (B, T) timestamps = batch.get('timestamps') # (B, T) or None mask = batch.get('mask') # (B, T) pos_ids = batch['positive_ids'] # (B,) neg_ids = batch['negative_ids'] # (B, num_neg) # Encode user sequence user_emb = self.encode(item_ids, timestamps, mask) # (B, D) # Score positive and negative items pos_emb = self.item_embedding(pos_ids) # (B, D) neg_emb = self.item_embedding(neg_ids) # (B, num_neg, D) pos_scores = (user_emb * pos_emb).sum(dim=-1) # (B,) neg_scores = torch.einsum('bd,bnd->bn', user_emb, neg_emb) # (B, num_neg) # BPR-style loss + BCE pos_labels = torch.ones_like(pos_scores) neg_labels = torch.zeros_like(neg_scores) loss_pos = F.binary_cross_entropy_with_logits(pos_scores, pos_labels) loss_neg = F.binary_cross_entropy_with_logits(neg_scores, neg_labels) loss = loss_pos + loss_neg return loss def _eval_forward(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor: """Eval forward: returns user embeddings.""" item_ids = batch['item_ids'] timestamps = batch.get('timestamps') mask = batch.get('mask') user_emb = self.encode(item_ids, timestamps, mask) return user_emb class SASRecBaseline(nn.Module): """ Standard SASRec baseline for comparison. Uses causal self-attention (O(n²) complexity). """ def __init__( self, num_items: int, embed_dim: int = 64, max_seq_len: int = 200, num_heads: int = 2, num_layers: int = 2, dropout: float = 0.1, ): super().__init__() self.num_items = num_items self.embed_dim = embed_dim self.max_seq_len = max_seq_len self.item_embedding = nn.Embedding(num_items + 1, embed_dim, padding_idx=0) self.position_embedding = nn.Embedding(max_seq_len, embed_dim) self.input_norm = nn.LayerNorm(embed_dim) self.input_dropout = nn.Dropout(dropout) encoder_layer = nn.TransformerEncoderLayer( d_model=embed_dim, nhead=num_heads, dim_feedforward=embed_dim * 4, dropout=dropout, activation='gelu', batch_first=True, norm_first=True ) self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers) self.output_norm = nn.LayerNorm(embed_dim) self._init_weights() def _init_weights(self): for name, param in self.named_parameters(): if 'weight' in name and param.dim() >= 2: nn.init.trunc_normal_(param, std=0.02) elif 'bias' in name: nn.init.zeros_(param) nn.init.zeros_(self.item_embedding.weight[0]) @property def item_embeddings(self): return self.item_embedding def encode(self, item_ids, timestamps=None, mask=None): B, T = item_ids.shape if mask is None: mask = (item_ids != 0) item_emb = self.item_embedding(item_ids) positions = torch.arange(T, device=item_ids.device).unsqueeze(0) positions = positions.clamp(max=self.max_seq_len - 1) item_emb = item_emb + self.position_embedding(positions) item_emb = self.input_norm(item_emb) item_emb = self.input_dropout(item_emb) causal_mask = torch.triu(torch.ones(T, T, device=item_ids.device, dtype=torch.bool), diagonal=1) src_key_padding_mask = ~mask output = self.encoder(item_emb, mask=causal_mask, src_key_padding_mask=src_key_padding_mask) lengths = mask.sum(dim=1).long() user_emb = output[torch.arange(B, device=item_ids.device), (lengths - 1).clamp(min=0)] user_emb = self.output_norm(user_emb) return user_emb def forward(self, batch): if self.training: item_ids = batch['item_ids'] timestamps = batch.get('timestamps') mask = batch.get('mask') pos_ids = batch['positive_ids'] neg_ids = batch['negative_ids'] user_emb = self.encode(item_ids, timestamps, mask) pos_emb = self.item_embedding(pos_ids) neg_emb = self.item_embedding(neg_ids) pos_scores = (user_emb * pos_emb).sum(dim=-1) neg_scores = torch.einsum('bd,bnd->bn', user_emb, neg_emb) loss_pos = F.binary_cross_entropy_with_logits(pos_scores, torch.ones_like(pos_scores)) loss_neg = F.binary_cross_entropy_with_logits(neg_scores, torch.zeros_like(neg_scores)) return loss_pos + loss_neg else: return self.encode(batch['item_ids'], batch.get('timestamps'), batch.get('mask'))