"""
MARS v2: Simplified and stabilized architecture.

Key changes from v1:
1. Replace unstable delta-rule state with temporal-gated linear attention
2. Simpler but more robust long-term branch
3. FFN layers for capacity
"""

import math
from typing import Dict, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F


def _init_parameters(module: nn.Module) -> None:
    """Truncated-normal init for >=2-D ``*weight*`` params, zeros for ``*bias*``.

    Parameters whose name contains neither substring (and 1-D weights such as
    LayerNorm scales) keep their constructor initialisation.
    """
    for name, param in module.named_parameters():
        if 'weight' in name and param.dim() >= 2:
            nn.init.trunc_normal_(param, std=0.02)
        elif 'bias' in name:
            nn.init.zeros_(param)


def _consecutive_deltas(timestamps: torch.Tensor) -> torch.Tensor:
    """Return non-negative gaps between consecutive columns of a (B, T) tensor.

    The first column has gap 0. Negative gaps (e.g. a real timestamp followed
    by a zero-filled padding slot) are clamped to 0. The result keeps the
    dtype of ``timestamps``.
    """
    deltas = torch.zeros_like(timestamps)
    deltas[:, 1:] = timestamps[:, 1:] - timestamps[:, :-1]
    return deltas.clamp(min=0)


def _safe_key_padding_mask(valid_mask: torch.Tensor) -> torch.Tensor:
    """Turn a (B, T) validity mask into an attention key-padding mask.

    This is ``~valid_mask`` except that rows without a single valid position
    get position 0 un-padded. Otherwise every key of such a row is masked and
    the attention softmax has nothing to normalise over, which yields NaN on
    many PyTorch versions. Rows with at least one valid position are
    unaffected.
    """
    padding = ~valid_mask  # fresh tensor, safe to edit in place
    if padding.shape[1] > 0:
        all_padded = padding.all(dim=1)
        padding[:, 0] = padding[:, 0] & ~all_padded
    return padding


def _bce_ranking_loss(
    user_emb: torch.Tensor,
    pos_emb: torch.Tensor,
    neg_emb: torch.Tensor,
) -> torch.Tensor:
    """Binary cross-entropy over dot-product scores.

    Positive scores are ``sum(user_emb * pos_emb, -1)`` with target 1;
    negative scores are ``einsum('bd,bnd->bn', user_emb, neg_emb)`` with
    target 0. Returns the sum of the two mean losses.
    """
    pos_scores = (user_emb * pos_emb).sum(dim=-1)
    neg_scores = torch.einsum('bd,bnd->bn', user_emb, neg_emb)
    loss_pos = F.binary_cross_entropy_with_logits(
        pos_scores, torch.ones_like(pos_scores)
    )
    loss_neg = F.binary_cross_entropy_with_logits(
        neg_scores, torch.zeros_like(neg_scores)
    )
    return loss_pos + loss_neg


class TemporalEncoding(nn.Module):
    """Multi-scale temporal encoding.

    Sums a projection of the log time gap to the previous step with a
    projection of sin/cos features at up to four fixed periods
    (3600, 86400, 604800, 2592000 — presumably seconds; confirm against the
    data pipeline), then applies LayerNorm.

    Args:
        embed_dim: Output embedding size.
        max_periods: How many of the built-in periods to use. Values larger
            than the table size simply use the whole table.

    Raises:
        ValueError: If ``max_periods`` is negative.
    """

    def __init__(self, embed_dim: int, max_periods: int = 4):
        super().__init__()
        if max_periods < 0:
            raise ValueError(f"max_periods must be >= 0, got {max_periods}")
        self.time_delta_proj = nn.Linear(1, embed_dim)
        periods = [3600, 86400, 604800, 2592000][:max_periods]
        self.register_buffer('periods', torch.tensor(periods, dtype=torch.float32))
        # Size the projection from the periods actually kept; sizing it from
        # max_periods breaks forward() whenever max_periods exceeds the table.
        self.periodic_proj = nn.Linear(len(periods) * 2, embed_dim)
        self.layernorm = nn.LayerNorm(embed_dim)

    def forward(self, timestamps: torch.Tensor) -> torch.Tensor:
        """Encode a (B, T) tensor of timestamps into (B, T, embed_dim).

        Integer and floating-point inputs are both accepted. Gaps and phases
        are computed in the input's own dtype before being cast down, so
        passing int64/float64 timestamps preserves precision.
        """
        work_dtype = self.time_delta_proj.weight.dtype

        # log1p is evaluated in float32 so large gaps cannot overflow a
        # reduced-precision working dtype.
        deltas = _consecutive_deltas(timestamps)
        log_deltas = torch.log1p(deltas.to(torch.float32)).to(work_dtype).unsqueeze(-1)
        delta_emb = self.time_delta_proj(log_deltas)

        # Range-reduce before forming the angle: 2*pi*ts/period on raw
        # epoch-scale values in float32 has an absolute error comparable to
        # the angle itself. sin/cos are unchanged by the reduction.
        if timestamps.is_floating_point():
            periods = self.periods.to(timestamps.dtype)
        else:
            periods = self.periods.round().to(torch.int64)
        phase = torch.remainder(timestamps.unsqueeze(-1), periods.view(1, 1, -1))
        angles = (
            2 * math.pi * phase.to(work_dtype)
            / self.periods.to(work_dtype).view(1, 1, -1)
        )

        periodic_features = torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)
        periodic_emb = self.periodic_proj(periodic_features)
        return self.layernorm(delta_emb + periodic_emb)


class TemporalGatedLinearAttention(nn.Module):
    """
    Temporal-Gated Linear Attention: O(n) attention with temporal decay.

    Uses the kernel trick: softmax(QK^T)V ≈ φ(Q) * (φ(K)^T * V)
    where φ is ELU + 1, making it O(n*d²) instead of O(n²*d).

    Added temporal gating: each step's key is scaled by a learned, per-head
    sigmoid gate computed from the log time gap to the previous step.

    The attention is non-causal (every position sees the whole sequence) and
    is followed by a post-norm residual and a pre-norm FFN residual.

    Args:
        embed_dim: Model width; must be divisible by ``num_heads``.
        num_heads: Number of attention heads.
        dropout: Dropout probability for the attention output and the FFN.

    Raises:
        ValueError: If ``embed_dim`` is not divisible by ``num_heads``.
    """

    def __init__(self, embed_dim: int, num_heads: int = 2, dropout: float = 0.1):
        super().__init__()
        if num_heads <= 0 or embed_dim % num_heads != 0:
            raise ValueError(
                f"embed_dim ({embed_dim}) must be divisible by num_heads ({num_heads})"
            )
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads

        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.out_proj = nn.Linear(embed_dim, embed_dim)

        # Temporal decay: learned per head
        self.decay_proj = nn.Linear(1, num_heads)  # log-delta → per-head decay weight

        self.norm = nn.LayerNorm(embed_dim)
        self.dropout = nn.Dropout(dropout)

        # FFN
        self.ffn = nn.Sequential(
            nn.LayerNorm(embed_dim),
            nn.Linear(embed_dim, embed_dim * 4),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(embed_dim * 4, embed_dim),
            nn.Dropout(dropout),
        )

    def _feature_map(self, x: torch.Tensor) -> torch.Tensor:
        """ELU + 1 feature map for linear attention (strictly positive)."""
        return F.elu(x) + 1

    def forward(
        self,
        x: torch.Tensor,
        timestamps: Optional[torch.Tensor] = None,
        mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Apply one gated linear-attention block.

        Args:
            x: Input of shape (B, T, embed_dim).
            timestamps: Optional (B, T) timestamps of any numeric dtype.
            mask: Optional (B, T) mask; truthy entries are valid positions.
                Keys and values at masked positions are zeroed.

        Returns:
            Tensor of shape (B, T, embed_dim).
        """
        B, T, D = x.shape
        H = self.num_heads
        d = self.head_dim

        # Project and reshape
        q = self._feature_map(self.q_proj(x)).view(B, T, H, d)
        k = self._feature_map(self.k_proj(x)).view(B, T, H, d)
        v = self.v_proj(x).view(B, T, H, d)

        # Temporal decay weights
        if timestamps is not None:
            # Gaps are taken in the timestamps' own dtype, converted to hours
            # in float32, then cast to the projection's dtype. Without the
            # cast, float64 timestamps hit a dtype mismatch in decay_proj.
            hours = _consecutive_deltas(timestamps).to(torch.float32) / 3600.0
            log_deltas = torch.log1p(hours).to(self.decay_proj.weight.dtype)
            log_deltas = log_deltas.unsqueeze(-1)  # (B, T, 1)
            decay_weights = torch.sigmoid(self.decay_proj(log_deltas))  # (B, T, H)
            # Weight keys by temporal decay
            k = k * decay_weights.unsqueeze(-1)  # (B, T, H, d)

        # Mask padding
        if mask is not None:
            mask_expanded = mask.unsqueeze(-1).unsqueeze(-1).to(k.dtype)  # (B, T, 1, 1)
            k = k * mask_expanded
            v = v * mask_expanded

        # Non-causal linear attention (bidirectional for long-term modeling):
        #   attn = φ(Q)(φ(K)^T V) / φ(Q)(φ(K)^T 1)
        # A causal variant would need a per-step (d, d) running state.
        kv = torch.einsum('bthd,bthe->bhde', k, v)  # (B, H, d, d)
        k_sum = k.sum(dim=1)  # (B, H, d)

        # Output: q @ kv / (q @ k_sum)
        numerator = torch.einsum('bthd,bhde->bthe', q, kv)  # (B, T, H, d)
        denominator = torch.einsum('bthd,bhd->bth', q, k_sum).unsqueeze(-1)  # (B, T, H, 1)

        # The epsilon keeps fully masked rows (k_sum == 0) finite.
        attn_out = numerator / (denominator + 1e-6)
        attn_out = attn_out.reshape(B, T, D)
        attn_out = self.out_proj(self.dropout(attn_out))

        # Residual + LayerNorm
        x = self.norm(x + attn_out)

        # FFN with residual
        x = x + self.ffn(x)
        return x


class CompressiveMemory(nn.Module):
    """Cross-attention memory compression.

    A fixed set of learned query tokens attends over the sequence and is
    refined by a post-norm FFN, yielding ``num_memory_tokens`` vectors per
    batch row.

    Args:
        embed_dim: Model width.
        num_memory_tokens: Number of learned memory queries.
        num_heads: Heads of the cross-attention.
        dropout: Dropout probability for the attention and the FFN.
    """

    def __init__(
        self,
        embed_dim: int,
        num_memory_tokens: int = 8,
        num_heads: int = 2,
        dropout: float = 0.1,
    ):
        super().__init__()
        self.memory_queries = nn.Parameter(torch.randn(num_memory_tokens, embed_dim) * 0.02)
        self.cross_attn = nn.MultiheadAttention(
            embed_dim, num_heads, batch_first=True, dropout=dropout
        )
        self.ffn = nn.Sequential(
            nn.Linear(embed_dim, embed_dim * 4),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(embed_dim * 4, embed_dim),
            nn.Dropout(dropout),
        )
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)

    def forward(
        self, sequence: torch.Tensor, mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """Compress ``sequence`` (B, T, D) into (B, num_memory_tokens, D).

        ``mask`` is an optional (B, T) boolean mask with True at valid
        positions. Rows with no valid position attend to position 0 so the
        result stays finite.
        """
        B = sequence.shape[0]
        queries = self.memory_queries.unsqueeze(0).expand(B, -1, -1)
        key_padding_mask = _safe_key_padding_mask(mask) if mask is not None else None
        attn_out, _ = self.cross_attn(
            queries, sequence, sequence, key_padding_mask=key_padding_mask
        )
        memory = self.norm1(queries + attn_out)
        memory = self.norm2(memory + self.ffn(memory))
        return memory


class AdaptiveFusionGate(nn.Module):
    """Learned fusion of long-term and short-term signals.

    A two-layer MLP over the concatenation of the three inputs produces a
    per-dimension sigmoid gate ``g``; the output is
    ``g * long_term + (1 - g) * short_term``. ``memory`` only influences the
    gate.
    """

    def __init__(self, embed_dim: int):
        super().__init__()
        self.gate = nn.Sequential(
            nn.Linear(embed_dim * 3, embed_dim),
            nn.GELU(),
            nn.Linear(embed_dim, embed_dim),
            nn.Sigmoid(),
        )

    def forward(
        self,
        long_term: torch.Tensor,
        short_term: torch.Tensor,
        memory: torch.Tensor,
    ) -> torch.Tensor:
        """Blend ``long_term`` and ``short_term`` (all inputs (..., embed_dim))."""
        g = self.gate(torch.cat([long_term, short_term, memory], dim=-1))
        return g * long_term + (1 - g) * short_term


class MARSv2(nn.Module):
    """
    MARS v2: Multi-scale Adaptive Recurrence with State compression

    Uses temporal-gated linear attention (O(n)) for long-term branch
    and standard causal self-attention for short-term branch.

    Args:
        num_items: Number of real items; id 0 is reserved for padding.
        embed_dim: Model width.
        max_seq_len: Size of the position-embedding table; longer inputs
            reuse the last position.
        short_term_len: Maximum number of most-recent items fed to the
            short-term branch.
        num_memory_tokens: Number of compressive-memory tokens.
        num_long_layers: Number of linear-attention layers.
        num_short_layers: Number of causal Transformer layers.
        num_heads: Attention heads used by every attention module.
        dropout: Dropout probability used throughout, including the
            compressive memory.
    """

    def __init__(
        self,
        num_items: int,
        embed_dim: int = 64,
        max_seq_len: int = 512,
        short_term_len: int = 50,
        num_memory_tokens: int = 8,
        num_long_layers: int = 3,
        num_short_layers: int = 2,
        num_heads: int = 2,
        dropout: float = 0.1,
    ):
        super().__init__()
        self.num_items = num_items
        self.embed_dim = embed_dim
        self.max_seq_len = max_seq_len
        self.short_term_len = short_term_len

        self.item_embedding = nn.Embedding(num_items + 1, embed_dim, padding_idx=0)
        self.temporal_encoding = TemporalEncoding(embed_dim)
        self.position_embedding = nn.Embedding(max_seq_len, embed_dim)
        self.input_norm = nn.LayerNorm(embed_dim)
        self.input_dropout = nn.Dropout(dropout)

        # Long-term branch: temporal-gated linear attention (O(n))
        self.long_layers = nn.ModuleList([
            TemporalGatedLinearAttention(embed_dim, num_heads, dropout)
            for _ in range(num_long_layers)
        ])

        # Compressive memory
        self.compressive_memory = CompressiveMemory(
            embed_dim, num_memory_tokens, num_heads, dropout
        )

        # Short-term branch: standard causal attention
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=embed_dim,
            nhead=num_heads,
            dim_feedforward=embed_dim * 4,
            dropout=dropout,
            activation='gelu',
            batch_first=True,
            norm_first=True,
        )
        self.short_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_short_layers)

        # Fusion
        self.fusion_gate = AdaptiveFusionGate(embed_dim)
        self.output_norm = nn.LayerNorm(embed_dim)
        self.output_proj = nn.Linear(embed_dim, embed_dim)

        self._init_weights()

    def _init_weights(self) -> None:
        """Re-initialise weights/biases and zero the padding embedding row."""
        _init_parameters(self)
        with torch.no_grad():
            self.item_embedding.weight[0].zero_()

    @property
    def item_embeddings(self) -> nn.Embedding:
        """The item embedding table (also used to score candidates)."""
        return self.item_embedding

    @staticmethod
    def _extract_recent_window(item_ids, timestamps, lengths, window):
        """Gather the last ``window`` valid items of every row in one shot.

        Rows are assumed to hold their valid items in positions
        ``0 .. lengths[b] - 1``. Shorter rows are right-padded with zeros.
        Requires ``window <= item_ids.shape[1]``.

        Returns:
            ``(ids, ts, valid)`` each of shape (B, window); ``ts`` is None
            when ``timestamps`` is None and ``valid`` is boolean.
        """
        offsets = torch.arange(window, device=item_ids.device)
        start = (lengths - window).clamp(min=0)
        # start + offsets never exceeds T - 1 because window <= T.
        gather_idx = start.unsqueeze(1) + offsets.unsqueeze(0)
        valid = offsets.unsqueeze(0) < lengths.clamp(max=window).unsqueeze(1)

        recent_ids = item_ids.gather(1, gather_idx).masked_fill(~valid, 0)
        recent_ts = None
        if timestamps is not None:
            recent_ts = timestamps.gather(1, gather_idx).masked_fill(~valid, 0)
        return recent_ids, recent_ts, valid

    def encode(
        self,
        item_ids: torch.Tensor,
        timestamps: Optional[torch.Tensor] = None,
        mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Encode interaction histories into user embeddings.

        Args:
            item_ids: (B, T) integer item ids, 0 meaning padding.
            timestamps: Optional (B, T) timestamps aligned with ``item_ids``.
            mask: Optional (B, T) validity mask; defaults to ``item_ids != 0``.

        Returns:
            (B, embed_dim) user embeddings.

        Note:
            The "last valid item" lookups index position ``mask.sum(1) - 1``,
            so sequences are assumed to be right-padded.
        """
        B, T = item_ids.shape
        device = item_ids.device
        mask = (item_ids != 0) if mask is None else mask.bool()
        batch_index = torch.arange(B, device=device)

        # Embeddings. Timestamps are passed in their original dtype so the
        # temporal encoding can range-reduce before dropping to float32.
        item_emb = self.item_embedding(item_ids)
        if timestamps is not None:
            item_emb = item_emb + self.temporal_encoding(timestamps)

        positions = torch.arange(T, device=device).unsqueeze(0).clamp(max=self.max_seq_len - 1)
        item_emb = self.input_norm(item_emb + self.position_embedding(positions))
        item_emb = self.input_dropout(item_emb)

        # Long-term branch
        long_repr = item_emb
        for layer in self.long_layers:
            long_repr = layer(long_repr, timestamps, mask)

        # Memory compression
        memory = self.compressive_memory(long_repr, mask)
        memory_summary = memory.mean(dim=1)

        # Last valid long-term
        lengths = mask.sum(dim=1).long()
        long_last = long_repr[batch_index, (lengths - 1).clamp(min=0)]

        # Short-term branch: extract last K valid items
        K = min(self.short_term_len, T)
        short_ids, short_ts, short_mask = self._extract_recent_window(
            item_ids, timestamps, lengths, K
        )

        short_emb = self.item_embedding(short_ids)
        if short_ts is not None:
            short_emb = short_emb + self.temporal_encoding(short_ts)

        short_pos = torch.arange(K, device=device).unsqueeze(0).clamp(max=self.max_seq_len - 1)
        short_emb = self.input_norm(short_emb + self.position_embedding(short_pos))

        causal_mask = torch.triu(
            torch.ones(K, K, device=device, dtype=torch.bool), diagonal=1
        )
        short_repr = self.short_encoder(
            short_emb,
            mask=causal_mask,
            src_key_padding_mask=_safe_key_padding_mask(short_mask),
        )

        short_lengths = short_mask.sum(dim=1).long()
        short_last = short_repr[batch_index, (short_lengths - 1).clamp(min=0)]

        # Fusion
        user_emb = self.fusion_gate(long_last, short_last, memory_summary)
        return self.output_proj(self.output_norm(user_emb))

    def forward(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
        """Return the training loss in train mode, user embeddings otherwise.

        ``batch`` must contain ``'item_ids'`` and may contain ``'timestamps'``
        and ``'mask'``. In train mode it must also contain ``'positive_ids'``
        and ``'negative_ids'`` (the latter with one extra trailing axis of
        negatives per row).
        """
        if not self.training:
            return self.encode(batch['item_ids'], batch.get('timestamps'), batch.get('mask'))

        user_emb = self.encode(batch['item_ids'], batch.get('timestamps'), batch.get('mask'))
        pos_emb = self.item_embedding(batch['positive_ids'])
        neg_emb = self.item_embedding(batch['negative_ids'])
        return _bce_ranking_loss(user_emb, pos_emb, neg_emb)


class SASRecBaseline(nn.Module):
    """SASRec baseline: causal Transformer encoder over item + position embeddings.

    Args:
        num_items: Number of real items; id 0 is reserved for padding.
        embed_dim: Model width.
        max_seq_len: Size of the position-embedding table; longer inputs
            reuse the last position.
        num_heads: Attention heads.
        num_layers: Number of Transformer layers.
        dropout: Dropout probability.
    """

    def __init__(self, num_items, embed_dim=64, max_seq_len=200, num_heads=2,
                 num_layers=2, dropout=0.1):
        super().__init__()
        self.num_items = num_items
        self.embed_dim = embed_dim
        self.max_seq_len = max_seq_len

        self.item_embedding = nn.Embedding(num_items + 1, embed_dim, padding_idx=0)
        self.position_embedding = nn.Embedding(max_seq_len, embed_dim)
        self.input_norm = nn.LayerNorm(embed_dim)
        self.input_dropout = nn.Dropout(dropout)

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=embed_dim,
            nhead=num_heads,
            dim_feedforward=embed_dim * 4,
            dropout=dropout,
            activation='gelu',
            batch_first=True,
            norm_first=True,
        )
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.output_norm = nn.LayerNorm(embed_dim)
        self._init_weights()

    def _init_weights(self) -> None:
        """Re-initialise weights/biases and zero the padding embedding row."""
        _init_parameters(self)
        with torch.no_grad():
            self.item_embedding.weight[0].zero_()

    @property
    def item_embeddings(self) -> nn.Embedding:
        """The item embedding table (also used to score candidates)."""
        return self.item_embedding

    def encode(
        self,
        item_ids: torch.Tensor,
        timestamps: Optional[torch.Tensor] = None,
        mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Encode (B, T) item ids into (B, embed_dim) user embeddings.

        ``timestamps`` is accepted for interface parity with ``MARSv2`` and
        ignored. ``mask`` defaults to ``item_ids != 0``. The output is read
        at position ``mask.sum(1) - 1``, so sequences are assumed to be
        right-padded.
        """
        B, T = item_ids.shape
        device = item_ids.device
        mask = (item_ids != 0) if mask is None else mask.bool()

        item_emb = self.item_embedding(item_ids)
        positions = torch.arange(T, device=device).unsqueeze(0).clamp(max=self.max_seq_len - 1)
        item_emb = self.input_norm(item_emb + self.position_embedding(positions))
        item_emb = self.input_dropout(item_emb)

        causal_mask = torch.triu(
            torch.ones(T, T, device=device, dtype=torch.bool), diagonal=1
        )
        output = self.encoder(
            item_emb,
            mask=causal_mask,
            src_key_padding_mask=_safe_key_padding_mask(mask),
        )

        lengths = mask.sum(dim=1).long()
        user_emb = output[torch.arange(B, device=device), (lengths - 1).clamp(min=0)]
        return self.output_norm(user_emb)

    def forward(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
        """Return the training loss in train mode, user embeddings otherwise.

        ``batch`` must contain ``'item_ids'`` and may contain ``'mask'``. In
        train mode it must also contain ``'positive_ids'`` and
        ``'negative_ids'``.
        """
        if not self.training:
            return self.encode(batch['item_ids'], mask=batch.get('mask'))

        user_emb = self.encode(batch['item_ids'], mask=batch.get('mask'))
        pos_emb = self.item_embedding(batch['positive_ids'])
        neg_emb = self.item_embedding(batch['negative_ids'])
        return _bce_ranking_loss(user_emb, pos_emb, neg_emb)