MARS-SeqRec / model_v2.py
CyberDancer's picture
MARS v2: Temporal-Gated Linear Attention for SeqRec
3989f8c verified
"""
MARS v2: Simplified and stabilized architecture.
Key changes from v1:
1. Replace unstable delta-rule state with temporal-gated linear attention
2. Simpler but more robust long-term branch
3. FFN layers for capacity
"""
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Dict
class TemporalEncoding(nn.Module):
    """Multi-scale temporal encoding.

    Combines two signals into one ``embed_dim`` vector per position:

    * the log-scaled gap to the previous event (``log1p`` of the clamped
      consecutive difference), and
    * sine/cosine features of the absolute timestamp at several fixed periods.

    The built-in periods are 3600, 86400, 604800 and 2592000, which reads as
    hour / day / week / 30 days — assumes timestamps are in seconds; confirm
    against the data pipeline.

    Args:
        embed_dim: Size of the returned encoding.
        max_periods: Upper bound on how many of the built-in periods to use.
            Only four periods exist, so larger values behave like 4.
    """

    def __init__(self, embed_dim: int, max_periods: int = 4):
        super().__init__()
        self.time_delta_proj = nn.Linear(1, embed_dim)
        periods = [3600, 86400, 604800, 2592000][:max_periods]
        self.register_buffer('periods', torch.tensor(periods, dtype=torch.float32))
        # Size the projection from the periods actually kept: slicing caps the
        # list at four entries, so `max_periods * 2` would not match the
        # feature width produced in forward() when max_periods > 4.
        self.periodic_proj = nn.Linear(len(periods) * 2, embed_dim)
        self.layernorm = nn.LayerNorm(embed_dim)

    def forward(self, timestamps: torch.Tensor) -> torch.Tensor:
        """Encode a ``(B, T)`` tensor of timestamps into ``(B, T, embed_dim)``.

        Integer and floating inputs of any precision are accepted; features
        are cast to the projection weights' dtype before the linear layers.
        """
        B, T = timestamps.shape  # also enforces a 2-D input
        param_dtype = self.time_delta_proj.weight.dtype

        # Gap to the previous event; the first position has no predecessor.
        time_deltas = torch.zeros_like(timestamps)
        time_deltas[:, 1:] = timestamps[:, 1:] - timestamps[:, :-1]
        # Out-of-order or padded (zero) timestamps yield negative gaps.
        time_deltas = time_deltas.clamp(min=0)
        log_deltas = torch.log1p(time_deltas).unsqueeze(-1).to(param_dtype)
        delta_emb = self.time_delta_proj(log_deltas)

        # Phase of the absolute timestamp within each period.
        ts_expanded = timestamps.unsqueeze(-1)
        periods = self.periods.view(1, 1, -1)
        angles = 2 * math.pi * ts_expanded / periods
        periodic_features = torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)
        periodic_emb = self.periodic_proj(periodic_features.to(param_dtype))

        return self.layernorm(delta_emb + periodic_emb)
class TemporalGatedLinearAttention(nn.Module):
    """
    Temporal-Gated Linear Attention: O(n) attention with temporal decay.

    Uses the kernel trick: softmax(QK^T)V ≈ φ(Q) * (φ(K)^T * V)
    where φ is ELU + 1, making it O(n*d²) instead of O(n²*d).

    Added temporal gating: each step's key is scaled by a sigmoid gate
    computed per head from the log-scaled gap to the previous event.

    The attention is non-causal: the key/value summary is summed over the
    whole sequence, so every position sees every unmasked position.

    The block is attention -> residual + LayerNorm -> pre-norm FFN with
    residual.

    Args:
        embed_dim: Model width; must be divisible by ``num_heads``.
        num_heads: Number of attention heads.
        dropout: Dropout probability for the attention output and the FFN.

    Raises:
        ValueError: If ``embed_dim`` is not divisible by ``num_heads``.
    """

    def __init__(self, embed_dim: int, num_heads: int = 2, dropout: float = 0.1):
        super().__init__()
        if num_heads <= 0 or embed_dim % num_heads != 0:
            # Fail here rather than with an opaque .view() error in forward().
            raise ValueError(
                f"embed_dim ({embed_dim}) must be divisible by num_heads ({num_heads})"
            )
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.out_proj = nn.Linear(embed_dim, embed_dim)
        # Temporal decay: learned per head
        self.decay_proj = nn.Linear(1, num_heads)  # log-delta -> per-head decay weight
        self.norm = nn.LayerNorm(embed_dim)
        self.dropout = nn.Dropout(dropout)
        # FFN
        self.ffn = nn.Sequential(
            nn.LayerNorm(embed_dim),
            nn.Linear(embed_dim, embed_dim * 4),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(embed_dim * 4, embed_dim),
            nn.Dropout(dropout),
        )

    def _feature_map(self, x):
        """ELU + 1 feature map for linear attention (strictly positive)."""
        return F.elu(x) + 1

    def forward(self, x, timestamps=None, mask=None):
        """Apply the attention block.

        Args:
            x: Input of shape ``(B, T, embed_dim)``.
            timestamps: Optional ``(B, T)`` tensor, integer or floating of any
                precision. Only consecutive differences are used. The gap is
                divided by 3600 before ``log1p`` — presumably seconds to
                hours; confirm the timestamp unit with the caller.
            mask: Optional ``(B, T)`` tensor, truthy at valid positions. Masked
                positions contribute no keys or values; their own outputs are
                still computed.

        Returns:
            Tensor of shape ``(B, T, embed_dim)``.
        """
        B, T, D = x.shape
        H = self.num_heads
        d = self.head_dim

        # Project and reshape
        q = self._feature_map(self.q_proj(x)).view(B, T, H, d)
        k = self._feature_map(self.k_proj(x)).view(B, T, H, d)
        v = self.v_proj(x).view(B, T, H, d)

        # Temporal decay weights
        if timestamps is not None:
            # Differences are taken in the timestamps' own dtype so large
            # integer / float64 values keep their precision.
            time_deltas = torch.zeros_like(timestamps)
            time_deltas[:, 1:] = timestamps[:, 1:] - timestamps[:, :-1]
            time_deltas = time_deltas.clamp(min=0)
            log_deltas = torch.log1p(time_deltas / 3600.0).unsqueeze(-1)  # (B, T, 1)
            # Match the projection's dtype; float64 timestamps would otherwise
            # raise a dtype mismatch inside nn.Linear.
            log_deltas = log_deltas.to(self.decay_proj.weight.dtype)
            decay_weights = torch.sigmoid(self.decay_proj(log_deltas))  # (B, T, H)
            # Weight keys by temporal decay
            k = k * decay_weights.unsqueeze(-1)  # (B, T, H, d)

        # Mask padding
        if mask is not None:
            mask_expanded = mask.unsqueeze(-1).unsqueeze(-1).to(k.dtype)  # (B, T, 1, 1)
            k = k * mask_expanded
            v = v * mask_expanded

        # Non-causal linear attention (bidirectional for long-term modeling):
        #   attn = φ(Q)(φ(K)^T V) / φ(Q)(φ(K)^T 1)
        # A causal variant would need a (B, T, H, d, d) running state.
        kv = torch.einsum('bthd,bthe->bhde', k, v)  # (B, H, d, d)
        k_sum = k.sum(dim=1)  # (B, H, d)

        numerator = torch.einsum('bthd,bhde->bthe', q, kv)  # (B, T, H, d)
        denominator = torch.einsum('bthd,bhd->bth', q, k_sum).unsqueeze(-1)  # (B, T, H, 1)
        # With every key masked both terms are zero; the epsilon keeps the
        # result at 0 instead of NaN.
        attn_out = numerator / (denominator + 1e-6)
        attn_out = attn_out.reshape(B, T, D)
        attn_out = self.out_proj(self.dropout(attn_out))

        # Residual + LayerNorm
        x = self.norm(x + attn_out)
        # FFN with residual
        x = x + self.ffn(x)
        return x
class CompressiveMemory(nn.Module):
    """Cross-attention memory compression.

    A fixed set of learned query vectors attends over the input sequence and
    returns ``num_memory_tokens`` summary vectors per batch element.

    Args:
        embed_dim: Model width.
        num_memory_tokens: Number of learned memory queries / output tokens.
        num_heads: Heads used by the cross-attention.
        dropout: Dropout probability for the attention and the FFN.
    """

    def __init__(self, embed_dim: int, num_memory_tokens: int = 8, num_heads: int = 2,
                 dropout: float = 0.1):
        super().__init__()
        self.memory_queries = nn.Parameter(torch.randn(num_memory_tokens, embed_dim) * 0.02)
        self.cross_attn = nn.MultiheadAttention(
            embed_dim, num_heads, batch_first=True, dropout=dropout
        )
        self.ffn = nn.Sequential(
            nn.Linear(embed_dim, embed_dim * 4), nn.GELU(), nn.Dropout(dropout),
            nn.Linear(embed_dim * 4, embed_dim), nn.Dropout(dropout),
        )
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)

    def forward(self, sequence, mask=None):
        """Compress ``sequence`` into memory tokens.

        Args:
            sequence: Tensor of shape ``(B, T, embed_dim)``.
            mask: Optional ``(B, T)`` tensor, truthy at valid positions. Any
                dtype convertible with ``.bool()`` is accepted.

        Returns:
            Tensor of shape ``(B, num_memory_tokens, embed_dim)``. For rows
            with no valid position the attention contribution is zeroed, so
            the memory depends only on the learned queries.
        """
        B = sequence.shape[0]
        queries = self.memory_queries.unsqueeze(0).expand(B, -1, -1)

        key_padding_mask = None
        empty_rows = None
        if mask is not None:
            # `~` on an integer mask is a bitwise NOT, so normalise first.
            key_padding_mask = ~mask.bool()
            # A row with every key masked can produce NaN in the attention
            # softmax. Let such rows attend everywhere and discard the result.
            empty_rows = key_padding_mask.all(dim=1)
            key_padding_mask = key_padding_mask & ~empty_rows.unsqueeze(1)

        attn_out, _ = self.cross_attn(
            queries, sequence, sequence, key_padding_mask=key_padding_mask
        )
        if empty_rows is not None:
            attn_out = attn_out.masked_fill(empty_rows.view(B, 1, 1), 0.0)

        memory = self.norm1(queries + attn_out)
        memory = self.norm2(memory + self.ffn(memory))
        return memory
class AdaptiveFusionGate(nn.Module):
    """Blend long-term and short-term vectors with a learned element-wise gate.

    The gate is a two-layer MLP with a sigmoid output, fed the concatenation
    of the long-term, short-term and memory vectors. The memory vector only
    influences the gate values; it is not part of the blended output.
    """

    def __init__(self, embed_dim: int):
        super().__init__()
        gate_layers = [
            nn.Linear(embed_dim * 3, embed_dim),
            nn.GELU(),
            nn.Linear(embed_dim, embed_dim),
            nn.Sigmoid(),
        ]
        self.gate = nn.Sequential(*gate_layers)

    def forward(self, long_term, short_term, memory):
        """Return ``g * long_term + (1 - g) * short_term`` with a learned ``g``."""
        gate_input = torch.cat((long_term, short_term, memory), dim=-1)
        keep_long = self.gate(gate_input)
        keep_short = 1 - keep_long
        return keep_long * long_term + keep_short * short_term
class MARSv2(nn.Module):
    """
    MARS v2: Multi-scale Adaptive Recurrence with State compression

    Uses temporal-gated linear attention (O(n)) for long-term branch
    and standard causal self-attention for short-term branch.

    The long-term branch runs over the full sequence and is also compressed
    into memory tokens; the short-term branch runs over the last
    ``short_term_len`` valid items. A learned gate fuses the final state of
    both branches, using the mean memory token as extra gate input.

    NOTE(review): sequence handling indexes the last valid item as
    ``lengths - 1`` and slices the short window as ``[start:length]``, which
    assumes valid items are stored first and padding (id 0) follows — confirm
    that the data loader right-pads.
    """

    def __init__(
        self,
        num_items: int,
        embed_dim: int = 64,
        max_seq_len: int = 512,
        short_term_len: int = 50,
        num_memory_tokens: int = 8,
        num_long_layers: int = 3,
        num_short_layers: int = 2,
        num_heads: int = 2,
        dropout: float = 0.1,
    ):
        super().__init__()
        self.num_items = num_items
        self.embed_dim = embed_dim
        self.max_seq_len = max_seq_len
        self.short_term_len = short_term_len

        # Index 0 is reserved for padding, hence num_items + 1 rows.
        self.item_embedding = nn.Embedding(num_items + 1, embed_dim, padding_idx=0)
        self.temporal_encoding = TemporalEncoding(embed_dim)
        self.position_embedding = nn.Embedding(max_seq_len, embed_dim)
        self.input_norm = nn.LayerNorm(embed_dim)
        self.input_dropout = nn.Dropout(dropout)

        # Long-term branch: temporal-gated linear attention (O(n))
        self.long_layers = nn.ModuleList([
            TemporalGatedLinearAttention(embed_dim, num_heads, dropout)
            for _ in range(num_long_layers)
        ])

        # Compressive memory
        self.compressive_memory = CompressiveMemory(embed_dim, num_memory_tokens, num_heads)

        # Short-term branch: standard causal attention
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=embed_dim, nhead=num_heads, dim_feedforward=embed_dim * 4,
            dropout=dropout, activation='gelu', batch_first=True, norm_first=True
        )
        self.short_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_short_layers)

        # Fusion
        self.fusion_gate = AdaptiveFusionGate(embed_dim)
        self.output_norm = nn.LayerNorm(embed_dim)
        self.output_proj = nn.Linear(embed_dim, embed_dim)

        self._init_weights()

    def _init_weights(self):
        """Truncated-normal init for matrices, zeros for biases and the pad row."""
        for name, param in self.named_parameters():
            if 'weight' in name and param.dim() >= 2:
                nn.init.trunc_normal_(param, std=0.02)
            elif 'bias' in name:
                nn.init.zeros_(param)
        # The loop above also re-initialises the embedding's padding row.
        nn.init.zeros_(self.item_embedding.weight[0])

    @property
    def item_embeddings(self):
        """The item embedding module (row 0 is the padding item)."""
        return self.item_embedding

    @staticmethod
    def _attendable_mask(mask):
        """Return a copy of ``mask`` where all-padding rows expose position 0.

        Softmax attention over a row with every key masked can yield NaN.
        Rows that have at least one valid position are returned unchanged.
        """
        safe = mask.clone()
        safe[:, 0] = safe[:, 0] | ~mask.any(dim=1)
        return safe

    def _embed_inputs(self, item_ids, timestamps):
        """Item + temporal + positional embedding followed by the input LayerNorm."""
        emb = self.item_embedding(item_ids)
        if timestamps is not None:
            emb = emb + self.temporal_encoding(timestamps.float())
        steps = item_ids.shape[1]
        # Sequences longer than max_seq_len reuse the last position embedding.
        positions = torch.arange(steps, device=item_ids.device).unsqueeze(0)
        positions = positions.clamp(max=self.max_seq_len - 1)
        return self.input_norm(emb + self.position_embedding(positions))

    def encode(self, item_ids, timestamps=None, mask=None):
        """Encode interaction histories into user embeddings.

        Args:
            item_ids: ``(B, T)`` integer tensor of item ids, 0 for padding.
            timestamps: Optional ``(B, T)`` tensor of event times.
            mask: Optional ``(B, T)`` tensor, truthy at valid positions.
                Defaults to ``item_ids != 0``. Non-bool masks are converted.

        Returns:
            ``(B, embed_dim)`` tensor. Rows with no valid item produce a
            finite embedding computed from padding inputs.
        """
        B, T = item_ids.shape
        device = item_ids.device
        # Downstream code uses `~mask`, which is only a logical NOT for bool.
        mask = (item_ids != 0) if mask is None else mask.bool()
        batch_idx = torch.arange(B, device=device)
        lengths = mask.sum(dim=1).long()

        # Long-term branch. Timestamps are forwarded unchanged; the layers
        # only use consecutive differences.
        long_repr = self.input_dropout(self._embed_inputs(item_ids, timestamps))
        for layer in self.long_layers:
            long_repr = layer(long_repr, timestamps, mask)

        # Memory compression
        memory = self.compressive_memory(long_repr, self._attendable_mask(mask))
        memory_summary = memory.mean(dim=1)

        # Last valid long-term state (index 0 for empty rows)
        long_last = long_repr[batch_idx, (lengths - 1).clamp(min=0)]

        # Short-term branch: last K valid items, moved to the front of a
        # K-wide window and zero-padded. Built with one gather instead of a
        # per-sample Python loop, which avoids a host sync per row.
        K = min(self.short_term_len, T)
        offsets = torch.arange(K, device=device).unsqueeze(0)        # (1, K)
        window_start = (lengths - K).clamp(min=0).unsqueeze(1)       # (B, 1)
        short_lengths = lengths.clamp(max=K)                         # (B,)
        short_mask = offsets < short_lengths.unsqueeze(1)            # (B, K)
        gather_idx = (window_start + offsets).clamp(max=T - 1)       # (B, K)

        short_ids = item_ids.gather(1, gather_idx).masked_fill(~short_mask, 0)
        short_ts = None
        if timestamps is not None:
            short_ts = timestamps.gather(1, gather_idx).masked_fill(~short_mask, 0)

        # The short branch does not apply input dropout.
        short_emb = self._embed_inputs(short_ids, short_ts)
        causal_mask = torch.triu(
            torch.ones(K, K, device=device, dtype=torch.bool), diagonal=1
        )
        short_repr = self.short_encoder(
            short_emb,
            mask=causal_mask,
            src_key_padding_mask=~self._attendable_mask(short_mask),
        )
        short_last = short_repr[batch_idx, (short_lengths - 1).clamp(min=0)]

        # Fusion
        user_emb = self.fusion_gate(long_last, short_last, memory_summary)
        return self.output_proj(self.output_norm(user_emb))

    def forward(self, batch):
        """Return the training loss in train mode, user embeddings otherwise.

        Training reads ``item_ids``, ``positive_ids`` and ``negative_ids``
        (plus optional ``timestamps`` / ``mask``) from ``batch`` and returns
        the sum of the BCE losses on positive and negative dot-product scores.
        Assumes ``positive_ids`` is ``(B,)`` and ``negative_ids`` is
        ``(B, N)`` — TODO confirm with the data loader.
        """
        if not self.training:
            return self.encode(batch['item_ids'], batch.get('timestamps'), batch.get('mask'))

        item_ids = batch['item_ids']
        timestamps = batch.get('timestamps')
        mask = batch.get('mask')
        pos_ids = batch['positive_ids']
        neg_ids = batch['negative_ids']

        user_emb = self.encode(item_ids, timestamps, mask)
        pos_emb = self.item_embedding(pos_ids)
        neg_emb = self.item_embedding(neg_ids)

        pos_scores = (user_emb * pos_emb).sum(dim=-1)
        neg_scores = torch.einsum('bd,bnd->bn', user_emb, neg_emb)

        loss_pos = F.binary_cross_entropy_with_logits(pos_scores, torch.ones_like(pos_scores))
        loss_neg = F.binary_cross_entropy_with_logits(neg_scores, torch.zeros_like(neg_scores))
        return loss_pos + loss_neg
class SASRecBaseline(nn.Module):
    """SASRec baseline.

    Item + positional embeddings feed a causal Transformer encoder; the
    output at the last valid position is the user embedding.

    NOTE(review): the last valid position is taken as ``lengths - 1``, which
    assumes valid items come first and padding (id 0) follows — confirm that
    the data loader right-pads.
    """

    def __init__(self, num_items, embed_dim=64, max_seq_len=200, num_heads=2, num_layers=2, dropout=0.1):
        super().__init__()
        self.num_items = num_items
        self.embed_dim = embed_dim
        self.max_seq_len = max_seq_len

        # Index 0 is reserved for padding, hence num_items + 1 rows.
        self.item_embedding = nn.Embedding(num_items + 1, embed_dim, padding_idx=0)
        self.position_embedding = nn.Embedding(max_seq_len, embed_dim)
        self.input_norm = nn.LayerNorm(embed_dim)
        self.input_dropout = nn.Dropout(dropout)

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=embed_dim, nhead=num_heads, dim_feedforward=embed_dim * 4,
            dropout=dropout, activation='gelu', batch_first=True, norm_first=True
        )
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.output_norm = nn.LayerNorm(embed_dim)

        self._init_weights()

    def _init_weights(self):
        """Truncated-normal init for matrices, zeros for biases and the pad row."""
        for name, param in self.named_parameters():
            if 'weight' in name and param.dim() >= 2:
                nn.init.trunc_normal_(param, std=0.02)
            elif 'bias' in name:
                nn.init.zeros_(param)
        # The loop above also re-initialises the embedding's padding row.
        nn.init.zeros_(self.item_embedding.weight[0])

    @property
    def item_embeddings(self):
        """The item embedding module (row 0 is the padding item)."""
        return self.item_embedding

    def encode(self, item_ids, timestamps=None, mask=None):
        """Encode interaction histories into user embeddings.

        Args:
            item_ids: ``(B, T)`` integer tensor of item ids, 0 for padding.
            timestamps: Accepted but not used by this model.
            mask: Optional ``(B, T)`` tensor, truthy at valid positions.
                Defaults to ``item_ids != 0``. Non-bool masks are converted.

        Returns:
            ``(B, embed_dim)`` tensor. Rows with no valid item produce a
            finite embedding computed from padding inputs.
        """
        B, T = item_ids.shape
        device = item_ids.device
        # `~mask` below is only a logical NOT for bool tensors.
        mask = (item_ids != 0) if mask is None else mask.bool()

        item_emb = self.item_embedding(item_ids)
        # Sequences longer than max_seq_len reuse the last position embedding.
        positions = torch.arange(T, device=device).unsqueeze(0).clamp(max=self.max_seq_len - 1)
        item_emb = self.input_norm(item_emb + self.position_embedding(positions))
        item_emb = self.input_dropout(item_emb)

        causal_mask = torch.triu(torch.ones(T, T, device=device, dtype=torch.bool), diagonal=1)

        # Attention over a row with every key masked can yield NaN, so rows
        # with no valid item expose position 0. Other rows are unchanged.
        key_padding = ~mask
        key_padding[:, 0] = key_padding[:, 0] & mask.any(dim=1)

        output = self.encoder(item_emb, mask=causal_mask, src_key_padding_mask=key_padding)

        lengths = mask.sum(dim=1).long()
        user_emb = output[torch.arange(B, device=device), (lengths - 1).clamp(min=0)]
        return self.output_norm(user_emb)

    def forward(self, batch):
        """Return the training loss in train mode, user embeddings otherwise.

        Training reads ``item_ids``, ``positive_ids`` and ``negative_ids``
        (plus optional ``mask``) from ``batch`` and returns the sum of the BCE
        losses on positive and negative dot-product scores. Assumes
        ``positive_ids`` is ``(B,)`` and ``negative_ids`` is ``(B, N)`` —
        TODO confirm with the data loader.
        """
        if not self.training:
            return self.encode(batch['item_ids'], mask=batch.get('mask'))

        item_ids = batch['item_ids']
        mask = batch.get('mask')
        pos_ids = batch['positive_ids']
        neg_ids = batch['negative_ids']

        user_emb = self.encode(item_ids, mask=mask)
        pos_emb = self.item_embedding(pos_ids)
        neg_emb = self.item_embedding(neg_ids)

        pos_scores = (user_emb * pos_emb).sum(dim=-1)
        neg_scores = torch.einsum('bd,bnd->bn', user_emb, neg_emb)

        loss_pos = F.binary_cross_entropy_with_logits(pos_scores, torch.ones_like(pos_scores))
        loss_neg = F.binary_cross_entropy_with_logits(neg_scores, torch.zeros_like(neg_scores))
        return loss_pos + loss_neg