| """ |
| MARS: Multi-scale Adaptive Recurrence with State compression |
| ============================================================ |
| |
| An innovative method for super long sequence modeling in sequential recommendation. |
| |
| Key innovations: |
| 1. Temporal-Aware Delta Network (TADN) for O(n) long-range modeling |
| - Explicit exponential temporal decay in state updates |
| - Input-dependent gating for selective memory retention |
| |
| 2. Compressive Memory Tokens |
| - Fixed-size learnable memory that compresses arbitrarily long histories |
| - Acts as information bottleneck (denoising effect per Rec2PM) |
| |
| 3. Dual-Branch Architecture with Learned Fusion |
| - Long-term branch: TADN layers processing full history at O(n) cost |
| - Short-term branch: Standard self-attention on recent K interactions |
| - Adaptive gating fusion that balances long/short-term signals per user |
| |
| 4. Multi-Scale Temporal Encoding |
| - Absolute time embeddings + relative time deltas + periodic components |
| - Captures daily/weekly/seasonal patterns in user behavior |
| |
| This combines ideas from HyTRec (2602.18283), Rec2PM (2602.11605), |
| SIGMA (2408.11451), and HSTU (2402.17152) into a unified architecture. |
| """ |
|
|
| import math |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from typing import Optional, Tuple, Dict |
|
|
|
|
class TemporalEncoding(nn.Module):
    """Multi-scale temporal encoding with periodic components.

    Two signals are computed per position, projected to ``embed_dim``,
    summed and layer-normalized:

    - ``log1p`` of the non-negative gap (in the unit of ``timestamps``) to
      the previous position of the same row; position 0 gets a gap of 0.
    - sin/cos of the absolute timestamp at fixed periods of 3600, 86400,
      604800 and 2592000 (one hour, day, week and 30 days if the
      timestamps are in seconds).

    Args:
        embed_dim: size of the produced embedding.
        max_periods: upper bound on how many of the four built-in periods
            are used. Larger values are capped at four.
    """

    def __init__(self, embed_dim: int, max_periods: int = 4):
        super().__init__()
        self.embed_dim = embed_dim

        # Scalar (log-scaled) gap -> embedding.
        self.time_delta_proj = nn.Linear(1, embed_dim)

        periods = [3600, 86400, 604800, 2592000][:max_periods]
        self.register_buffer('periods', torch.tensor(periods, dtype=torch.float32))
        # Size the projection from the periods actually kept: the slice above
        # silently caps the list, so ``max_periods * 2`` would not match the
        # feature width produced in ``forward`` when ``max_periods`` > 4.
        self.periodic_proj = nn.Linear(len(periods) * 2, embed_dim)

        self.layernorm = nn.LayerNorm(embed_dim)

    def forward(self, timestamps: torch.Tensor) -> torch.Tensor:
        """
        Args:
            timestamps: (batch, seq_len) absolute timestamps in seconds
        Returns:
            temporal_emb: (batch, seq_len, embed_dim)
        """
        # Gap to the previous event of the same row; the first event has no
        # predecessor and keeps a gap of 0.
        time_deltas = torch.zeros_like(timestamps)
        time_deltas[:, 1:] = timestamps[:, 1:] - timestamps[:, :-1]
        # Decreasing timestamps (e.g. zero padding after real events) would
        # yield negative gaps; clamp them so log1p stays defined.
        time_deltas = time_deltas.clamp(min=0)

        log_deltas = torch.log1p(time_deltas).unsqueeze(-1)
        delta_emb = self.time_delta_proj(log_deltas)

        # One sin/cos pair per period: (batch, seq_len, 2 * num_periods).
        ts_expanded = timestamps.unsqueeze(-1)
        periods = self.periods.view(1, 1, -1)
        angles = 2 * math.pi * ts_expanded / periods
        periodic_features = torch.cat([
            torch.sin(angles),
            torch.cos(angles)
        ], dim=-1)
        periodic_emb = self.periodic_proj(periodic_features)

        temporal_emb = self.layernorm(delta_emb + periodic_emb)
        return temporal_emb
|
|
|
|
class TADNLayer(nn.Module):
    """Temporal-Aware Delta Network layer.

    A recurrent layer that walks the sequence once, carrying a state ``S`` of
    shape (batch, state_dim, embed_dim). At every step ``t`` the state is
    updated element-wise and then read with the query:

        w_t     = beta_t * k_t                                (batch, state_dim)
        erase_t = clamp(outer(w_t, g_t), 0, 1)
        S_t     = clamp(S_{t-1} * (1 - erase_t) + outer(w_t, v_t), -10, 10)
        o_t     = q_t @ S_t                                   (batch, embed_dim)

    The erase gate mixes an input-dependent, time-decayed part with a
    learned static part:

        a   = sigmoid(alpha)
        g_t = a * sigmoid(W_g [x_t, x_t - x_{t-1}]) * tau_t
              + (1 - a) * sigmoid(gate_static)

    ``tau_t`` shrinks with the age of step ``t`` relative to the last valid
    step of the row:

        with timestamps:    tau_t = exp(-log1p(age_seconds / 3600)
                                        / (|time_scale| * 10 + 1))
        without timestamps: tau_t = exp(-age_steps / row_length)

    Steps whose mask entry is False leave the state untouched.

    Args:
        embed_dim: width of the input, the values and the output.
        state_dim: width of queries/keys, i.e. number of state rows.
        dropout: dropout applied to the read-out before the output projection.
    """

    def __init__(self, embed_dim: int, state_dim: int = 64, dropout: float = 0.1):
        super().__init__()
        self.embed_dim = embed_dim
        self.state_dim = state_dim

        # Read / write projections.
        self.q_proj = nn.Linear(embed_dim, state_dim)
        self.k_proj = nn.Linear(embed_dim, state_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)

        # Gate sees the input and its first difference, hence 2 * embed_dim.
        self.gate_proj = nn.Linear(embed_dim * 2, embed_dim)
        self.beta_proj = nn.Linear(embed_dim, state_dim)

        # Mixing logit between dynamic and static gate, and decay temperature.
        self.alpha = nn.Parameter(torch.tensor(0.5))
        self.time_scale = nn.Parameter(torch.tensor(1.0))

        self.gate_static = nn.Parameter(torch.ones(embed_dim) * 0.5)

        self.out_proj = nn.Linear(embed_dim, embed_dim)
        self.layernorm = nn.LayerNorm(embed_dim)
        self.dropout = nn.Dropout(dropout)

    def _temporal_decay(
        self,
        x: torch.Tensor,
        timestamps: Optional[torch.Tensor],
        valid: Optional[torch.Tensor],
    ) -> torch.Tensor:
        """Return tau of shape (batch, seq_len, 1): 1 at the last valid step, smaller for older steps."""
        B, T, _ = x.shape
        positions = torch.arange(T, device=x.device)

        # Index of the last valid step per row. Using the final column
        # unconditionally would pick a padding slot for padded rows and make
        # the decay depend on how much padding the batch happens to carry.
        if valid is None:
            last_idx = torch.full((B,), T - 1, dtype=torch.long, device=x.device)
        else:
            last_idx = (
                positions.unsqueeze(0).expand(B, T).masked_fill(~valid, -1)
                .max(dim=1).values.clamp(min=0)
            )

        if timestamps is not None:
            t_last = timestamps.gather(1, last_idx.unsqueeze(1)).unsqueeze(-1)
            age = (t_last - timestamps.unsqueeze(-1)).clamp(min=0)
            # Log-compress the age in hours so very old events are not
            # driven to exactly zero.
            log_age = torch.log1p(age / 3600.0)
            tau = torch.exp(-log_age / (torch.abs(self.time_scale) * 10.0 + 1.0))
        else:
            # Position surrogate: age in steps before the last valid step,
            # normalised by the row's own length so padding does not matter.
            age = (last_idx.unsqueeze(1) - positions.unsqueeze(0)).clamp(min=0)
            span = (last_idx + 1).unsqueeze(1).to(x.dtype)
            tau = torch.exp(-age.to(x.dtype) / (span + 1e-6)).unsqueeze(-1)

        return tau.to(x.dtype)

    def forward(
        self,
        x: torch.Tensor,
        timestamps: Optional[torch.Tensor] = None,
        mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """
        Args:
            x: (batch, seq_len, embed_dim) input sequence
            timestamps: (batch, seq_len) timestamps for temporal decay
            mask: (batch, seq_len) boolean mask (True = valid)
        Returns:
            output: (batch, seq_len, embed_dim)
        """
        B, T, D = x.shape
        if T == 0:
            # Nothing to scan; keep the (batch, 0, embed_dim) shape.
            return x

        q = self.q_proj(x)
        k = self.k_proj(x)
        v = self.v_proj(x)

        # Per-key write strength in (0, 1).
        beta = torch.sigmoid(self.beta_proj(x))

        valid = mask.to(torch.bool) if mask is not None else None
        tau = self._temporal_decay(x, timestamps, valid)

        # Gate input: current token plus its change w.r.t. the previous one
        # (the first token is compared against zeros).
        x_shifted = torch.cat([torch.zeros_like(x[:, :1]), x[:, :-1]], dim=1)
        gate_input = torch.cat([x, x - x_shifted], dim=-1)

        alpha = torch.sigmoid(self.alpha)
        g_dynamic = torch.sigmoid(self.gate_proj(gate_input))
        g = alpha * g_dynamic * tau + (1 - alpha) * torch.sigmoid(self.gate_static)

        # Loop-invariant pieces of the update.
        write_key = beta * k
        keep = valid.to(x.dtype) if valid is not None else None

        # State follows the input's dtype/device.
        S = x.new_zeros(B, self.state_dim, D)
        outputs = []

        for t in range(T):
            w_t = write_key[:, t]
            erase = torch.einsum('bs,bd->bsd', w_t, g[:, t]).clamp(0, 1)
            write = torch.einsum('bs,bd->bsd', w_t, v[:, t])

            if keep is not None:
                # Padded steps neither erase nor write.
                keep_t = keep[:, t].view(B, 1, 1)
                S = S * (1 - erase * keep_t) + write * keep_t
            else:
                S = S * (1 - erase) + write

            # Bound the state so long sequences cannot blow it up.
            S = S.clamp(-10, 10)

            outputs.append(torch.einsum('bs,bsd->bd', q[:, t], S))

        output = torch.stack(outputs, dim=1)
        output = self.out_proj(self.dropout(output))
        output = self.layernorm(x + output)

        return output
|
|
|
|
class CompressiveMemory(nn.Module):
    """Compressive Memory Module.

    Compresses a sequence into a fixed number of memory tokens: learnable
    query vectors cross-attend to the sequence, followed by a residual
    feed-forward block with layer normalization after each residual sum.

    Args:
        embed_dim: width of the sequence and of the memory tokens.
        num_memory_tokens: number of learnable queries / returned tokens.
        num_heads: attention heads of the cross-attention.
        dropout: dropout used inside the attention and the feed-forward
            block (defaults to the previously hard-coded 0.1).
    """

    def __init__(
        self,
        embed_dim: int,
        num_memory_tokens: int = 8,
        num_heads: int = 2,
        dropout: float = 0.1,
    ):
        super().__init__()
        self.num_memory_tokens = num_memory_tokens

        # Learnable queries, shared by every sequence in the batch.
        self.memory_queries = nn.Parameter(
            torch.randn(num_memory_tokens, embed_dim) * 0.02
        )

        self.cross_attn = nn.MultiheadAttention(
            embed_dim=embed_dim,
            num_heads=num_heads,
            batch_first=True,
            dropout=dropout
        )

        self.ffn = nn.Sequential(
            nn.Linear(embed_dim, embed_dim * 4),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(embed_dim * 4, embed_dim),
            nn.Dropout(dropout),
        )

        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)

    def forward(
        self,
        sequence: torch.Tensor,
        mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """
        Args:
            sequence: (batch, seq_len, embed_dim) - encoded sequence
            mask: (batch, seq_len) boolean mask (True = valid, False = padding)
        Returns:
            memory: (batch, num_memory_tokens, embed_dim). Rows whose mask has
                no valid position receive no attention contribution, so their
                memory depends only on the learned queries.
        """
        B = sequence.shape[0]

        queries = self.memory_queries.unsqueeze(0).expand(B, -1, -1)

        if mask is not None:
            valid = mask.to(torch.bool)
            empty_rows = ~valid.any(dim=1)
            # MultiheadAttention expects True = "ignore this key". A row with
            # every key ignored makes the attention softmax return NaN, so
            # such rows are left unmasked here and zeroed after the attention.
            key_padding_mask = (~valid) & (~empty_rows).unsqueeze(1)
        else:
            empty_rows = None
            key_padding_mask = None

        attn_out, _ = self.cross_attn(
            query=queries,
            key=sequence,
            value=sequence,
            key_padding_mask=key_padding_mask
        )
        if empty_rows is not None:
            attn_out = attn_out.masked_fill(empty_rows.view(B, 1, 1), 0.0)

        memory = self.norm1(queries + attn_out)
        memory = self.norm2(memory + self.ffn(memory))

        return memory
|
|
|
|
class ShortTermAttention(nn.Module):
    """Standard self-attention block for short-term (recent) interactions.

    A stack of pre-norm ``nn.TransformerEncoderLayer`` blocks run with a
    causal mask, so each position only attends to itself and earlier
    positions. Cost is quadratic in the window length K.

    Args:
        embed_dim: model width.
        num_heads: attention heads per layer.
        num_layers: number of encoder layers.
        dropout: dropout rate inside the encoder layers.
    """

    def __init__(self, embed_dim: int, num_heads: int = 2, num_layers: int = 2, dropout: float = 0.1):
        super().__init__()
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=embed_dim,
            nhead=num_heads,
            dim_feedforward=embed_dim * 4,
            dropout=dropout,
            activation='gelu',
            batch_first=True,
            norm_first=True
        )
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

    def forward(
        self,
        x: torch.Tensor,
        mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """
        Args:
            x: (batch, K, embed_dim) recent interactions
            mask: (batch, K) boolean mask (True = valid)
        Returns:
            output: (batch, K, embed_dim). Rows without any valid position
                are returned as zeros.
        """
        T = x.shape[1]

        # True above the diagonal = "may not attend": blocks future positions.
        causal_mask = torch.triu(
            torch.ones(T, T, device=x.device, dtype=torch.bool), diagonal=1
        )

        if mask is not None:
            valid = mask.to(torch.bool)
            empty_rows = ~valid.any(dim=1)
            # The encoder expects True = "padding key". If every key of a row
            # is padding, each of its queries has nothing left to attend to
            # and the softmax can produce NaN, which would also poison the
            # gradients. Leave such rows unmasked and blank their output.
            src_key_padding_mask = (~valid) & (~empty_rows).unsqueeze(1)
        else:
            empty_rows = None
            src_key_padding_mask = None

        output = self.encoder(
            x,
            mask=causal_mask,
            src_key_padding_mask=src_key_padding_mask
        )
        if empty_rows is not None:
            output = output.masked_fill(empty_rows.view(-1, 1, 1), 0.0)
        return output
|
|
|
|
class AdaptiveFusionGate(nn.Module):
    """Learned convex blend of the long-term and short-term representations.

    A small MLP looks at the concatenation of both representations and the
    memory summary and emits one weight in (0, 1) per feature:

        fused = w * long_term + (1 - w) * short_term
    """

    def __init__(self, embed_dim: int):
        super().__init__()
        stages = [
            nn.Linear(embed_dim * 3, embed_dim),  # three inputs are concatenated
            nn.GELU(),
            nn.Linear(embed_dim, embed_dim),
            nn.Sigmoid(),  # squashes to a per-feature mixing weight
        ]
        self.gate = nn.Sequential(*stages)

    def forward(
        self,
        long_term: torch.Tensor,
        short_term: torch.Tensor,
        memory: torch.Tensor
    ) -> torch.Tensor:
        """
        Args:
            long_term: (batch, embed_dim)
            short_term: (batch, embed_dim)
            memory: (batch, embed_dim) compressed memory summary; only
                influences the mixing weight, it is not part of the blend
        Returns:
            fused: (batch, embed_dim)
        """
        features = torch.cat((long_term, short_term, memory), dim=-1)
        weight = self.gate(features)
        long_part = weight * long_term
        short_part = (1 - weight) * short_term
        return long_part + short_part
|
|
|
|
class MARS(nn.Module):
    """
    MARS: Multi-scale Adaptive Recurrence with State compression

    Architecture:
        Input: Full user interaction sequence + timestamps
            |
            v
        [Item Embedding + Temporal Encoding]
            |
            +---- Long-term Branch (TADN layers, O(n))
            |         |
            |     [Compressive Memory] -> memory tokens
            |         |
            +---- Short-term Branch (Self-Attention on recent K items)
            |
            v
        [Adaptive Fusion Gate]
            |
            v
        [Prediction Head] -> next item scores

    Args:
        num_items: number of unique items
        embed_dim: embedding dimension
        max_seq_len: maximum sequence length (can be very long, e.g. 2048)
        short_term_len: number of recent items for short-term branch
        num_memory_tokens: number of compressive memory tokens
        num_tadn_layers: number of TADN layers in long-term branch
        num_attn_layers: number of attention layers in short-term branch
        num_heads: number of attention heads
        state_dim: state dimension for TADN
        dropout: dropout rate
    """

    def __init__(
        self,
        num_items: int,
        embed_dim: int = 64,
        max_seq_len: int = 512,
        short_term_len: int = 50,
        num_memory_tokens: int = 8,
        num_tadn_layers: int = 3,
        num_attn_layers: int = 2,
        num_heads: int = 2,
        state_dim: int = 64,
        dropout: float = 0.1,
    ):
        super().__init__()
        self.num_items = num_items
        self.embed_dim = embed_dim
        self.max_seq_len = max_seq_len
        self.short_term_len = short_term_len
        self.num_memory_tokens = num_memory_tokens

        # Index 0 is reserved for padding, hence num_items + 1 rows.
        self.item_embedding = nn.Embedding(num_items + 1, embed_dim, padding_idx=0)

        self.temporal_encoding = TemporalEncoding(embed_dim)

        self.position_embedding = nn.Embedding(max_seq_len, embed_dim)

        self.input_norm = nn.LayerNorm(embed_dim)
        self.input_dropout = nn.Dropout(dropout)

        # Long-term branch.
        self.tadn_layers = nn.ModuleList([
            TADNLayer(embed_dim, state_dim, dropout)
            for _ in range(num_tadn_layers)
        ])

        self.compressive_memory = CompressiveMemory(
            embed_dim, num_memory_tokens, num_heads
        )

        # Short-term branch.
        self.short_term_attn = ShortTermAttention(
            embed_dim, num_heads, num_attn_layers, dropout
        )

        self.fusion_gate = AdaptiveFusionGate(embed_dim)

        self.output_norm = nn.LayerNorm(embed_dim)
        self.output_proj = nn.Linear(embed_dim, embed_dim)

        self._init_weights()

    def _init_weights(self):
        """Initialize with truncated normal distribution."""
        for name, param in self.named_parameters():
            if 'weight' in name and param.dim() >= 2:
                nn.init.trunc_normal_(param, std=0.02)
            elif 'bias' in name:
                nn.init.zeros_(param)

        nn.init.trunc_normal_(self.item_embedding.weight, std=0.02)
        # The generic loop above overwrote the padding row; zero it again.
        nn.init.zeros_(self.item_embedding.weight[0])

    @property
    def item_embeddings(self):
        """Access item embedding table (for evaluation)."""
        return self.item_embedding

    @staticmethod
    def _recent_window(
        item_ids: torch.Tensor,
        timestamps: Optional[torch.Tensor],
        lengths: torch.Tensor,
        window: int,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], torch.Tensor]:
        """Slice the last ``window`` valid steps of every row, left-aligned and zero-padded.

        Rows are read as ``row[max(0, len - window):len]``; rows shorter than
        ``window`` are padded with zeros on the right. ``window`` must not
        exceed the sequence length.

        Returns:
            (short_item_ids, short_timestamps or None, short_mask), each of
            shape (batch, window); ``short_mask`` is True on real entries.
        """
        offsets = torch.arange(window, device=item_ids.device).unsqueeze(0)
        start = (lengths - window).clamp(min=0).unsqueeze(1)
        gather_idx = start + offsets
        short_mask = offsets < lengths.unsqueeze(1)

        short_item_ids = item_ids.gather(1, gather_idx).masked_fill(~short_mask, 0)
        short_ts = None
        if timestamps is not None:
            short_ts = timestamps.gather(1, gather_idx).masked_fill(~short_mask, 0)
        return short_item_ids, short_ts, short_mask

    def encode(
        self,
        item_ids: torch.Tensor,
        timestamps: Optional[torch.Tensor] = None,
        mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """
        Encode a full sequence into user representations.

        The last valid step of a row is read at index ``mask.sum() - 1`` and
        the short-term window ends at ``mask.sum()``, so valid items are
        assumed to occupy the leading positions of each row (right padding).

        Args:
            item_ids: (batch, seq_len) item indices (0 = padding)
            timestamps: (batch, seq_len) timestamps in seconds
            mask: (batch, seq_len) boolean mask (True = valid); derived from
                ``item_ids != 0`` when omitted
        Returns:
            user_emb: (batch, embed_dim) final user representation
        """
        B, T = item_ids.shape
        device = item_ids.device

        if mask is None:
            mask = (item_ids != 0)

        # ---- shared input embedding ----
        item_emb = self.item_embedding(item_ids)

        if timestamps is not None:
            temp_emb = self.temporal_encoding(timestamps.float())
            item_emb = item_emb + temp_emb

        # Sequences longer than the table reuse its last position embedding.
        positions = torch.arange(T, device=device).unsqueeze(0)
        positions = positions.clamp(max=self.max_seq_len - 1)
        pos_emb = self.position_embedding(positions)

        item_emb = self.input_norm(item_emb + pos_emb)
        item_emb = self.input_dropout(item_emb)

        # ---- long-term branch ----
        long_term_repr = item_emb
        for tadn in self.tadn_layers:
            long_term_repr = tadn(long_term_repr, timestamps, mask)

        memory = self.compressive_memory(long_term_repr, mask)
        memory_summary = memory.mean(dim=1)

        lengths = mask.sum(dim=1).long()
        batch_index = torch.arange(B, device=device)
        # Empty rows fall back to index 0 instead of wrapping around to -1.
        long_term_last = long_term_repr[batch_index, (lengths - 1).clamp(min=0)]

        # ---- short-term branch ----
        K = min(self.short_term_len, T)
        # Vectorised window extraction: no per-row Python loop and no
        # host/device synchronisation on the sequence lengths.
        short_item_ids, short_ts, short_mask = self._recent_window(
            item_ids, timestamps, lengths, K
        )

        short_emb = self.item_embedding(short_item_ids)

        if short_ts is not None:
            short_temp = self.temporal_encoding(short_ts.float())
            short_emb = short_emb + short_temp

        # Positions restart at 0 inside the window.
        short_positions = torch.arange(K, device=device).unsqueeze(0)
        short_positions = short_positions.clamp(max=self.max_seq_len - 1)
        short_emb = short_emb + self.position_embedding(short_positions)
        short_emb = self.input_norm(short_emb)

        short_term_repr = self.short_term_attn(short_emb, short_mask)

        short_lengths = short_mask.sum(dim=1).long()
        short_term_last = short_term_repr[
            batch_index, (short_lengths - 1).clamp(min=0)
        ]

        # ---- fusion ----
        user_emb = self.fusion_gate(long_term_last, short_term_last, memory_summary)
        user_emb = self.output_proj(self.output_norm(user_emb))

        return user_emb

    def forward(
        self,
        batch: Dict[str, torch.Tensor]
    ) -> torch.Tensor:
        """
        Forward pass; behaviour depends on ``self.training``.

        Expected batch format (flat tensors, matching Yambda convention):
            - item_ids: (batch, max_seq_len) padded item sequences
            - timestamps: (batch, max_seq_len) padded timestamps (optional)
            - mask: (batch, max_seq_len) boolean mask (optional)
            - positive_ids: (batch,) positive next items (training only)
            - negative_ids: (batch, num_neg) negative items (training only)

        Returns:
            Training mode: scalar BCE loss.
            Eval mode: (batch, embed_dim) user embeddings.
        """
        if self.training:
            return self._training_forward(batch)
        else:
            return self._eval_forward(batch)

    def _training_forward(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
        """Compute training loss with next-item prediction."""
        item_ids = batch['item_ids']
        timestamps = batch.get('timestamps')
        mask = batch.get('mask')
        pos_ids = batch['positive_ids']
        neg_ids = batch['negative_ids']

        user_emb = self.encode(item_ids, timestamps, mask)

        # Dot-product scores against the (shared) item embedding table.
        pos_emb = self.item_embedding(pos_ids)
        neg_emb = self.item_embedding(neg_ids)

        pos_scores = (user_emb * pos_emb).sum(dim=-1)
        neg_scores = torch.einsum('bd,bnd->bn', user_emb, neg_emb)

        pos_labels = torch.ones_like(pos_scores)
        neg_labels = torch.zeros_like(neg_scores)

        # Each term is averaged over its own elements before being summed.
        loss_pos = F.binary_cross_entropy_with_logits(pos_scores, pos_labels)
        loss_neg = F.binary_cross_entropy_with_logits(neg_scores, neg_labels)

        loss = loss_pos + loss_neg
        return loss

    def _eval_forward(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
        """Eval forward: returns user embeddings."""
        item_ids = batch['item_ids']
        timestamps = batch.get('timestamps')
        mask = batch.get('mask')

        user_emb = self.encode(item_ids, timestamps, mask)
        return user_emb
|
|
|
|
class SASRecBaseline(nn.Module):
    """
    Standard SASRec baseline for comparison.
    Uses causal self-attention (O(n²) complexity).

    Args:
        num_items: number of unique items (index 0 is reserved for padding)
        embed_dim: embedding dimension
        max_seq_len: size of the position embedding table
        num_heads: number of attention heads
        num_layers: number of encoder layers
        dropout: dropout rate
    """

    def __init__(
        self,
        num_items: int,
        embed_dim: int = 64,
        max_seq_len: int = 200,
        num_heads: int = 2,
        num_layers: int = 2,
        dropout: float = 0.1,
    ):
        super().__init__()
        self.num_items = num_items
        self.embed_dim = embed_dim
        self.max_seq_len = max_seq_len

        self.item_embedding = nn.Embedding(num_items + 1, embed_dim, padding_idx=0)
        self.position_embedding = nn.Embedding(max_seq_len, embed_dim)

        self.input_norm = nn.LayerNorm(embed_dim)
        self.input_dropout = nn.Dropout(dropout)

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=embed_dim,
            nhead=num_heads,
            dim_feedforward=embed_dim * 4,
            dropout=dropout,
            activation='gelu',
            batch_first=True,
            norm_first=True
        )
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

        self.output_norm = nn.LayerNorm(embed_dim)

        self._init_weights()

    def _init_weights(self):
        """Truncated-normal init for matrices, zeros for biases and the padding row."""
        for name, param in self.named_parameters():
            if 'weight' in name and param.dim() >= 2:
                nn.init.trunc_normal_(param, std=0.02)
            elif 'bias' in name:
                nn.init.zeros_(param)
        # The loop above also overwrote the padding embedding; reset it.
        nn.init.zeros_(self.item_embedding.weight[0])

    @property
    def item_embeddings(self):
        """Access item embedding table (for evaluation)."""
        return self.item_embedding

    def encode(
        self,
        item_ids: torch.Tensor,
        timestamps: Optional[torch.Tensor] = None,
        mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Encode a sequence into one user vector per row.

        The user vector is read at index ``mask.sum() - 1``, so valid items
        are assumed to occupy the leading positions of each row.

        Args:
            item_ids: (batch, seq_len) item indices (0 = padding)
            timestamps: accepted for signature parity with ``MARS.encode``;
                not used by this model
            mask: (batch, seq_len) boolean mask (True = valid); derived from
                ``item_ids != 0`` when omitted
        Returns:
            user_emb: (batch, embed_dim)
        """
        B, T = item_ids.shape
        device = item_ids.device
        if mask is None:
            mask = (item_ids != 0)
        valid = mask.to(torch.bool)

        item_emb = self.item_embedding(item_ids)
        # Sequences longer than the table reuse its last position embedding.
        positions = torch.arange(T, device=device).unsqueeze(0)
        positions = positions.clamp(max=self.max_seq_len - 1)
        item_emb = item_emb + self.position_embedding(positions)
        item_emb = self.input_norm(item_emb)
        item_emb = self.input_dropout(item_emb)

        causal_mask = torch.triu(torch.ones(T, T, device=device, dtype=torch.bool), diagonal=1)

        # A row with every key flagged as padding leaves its queries nothing
        # to attend to, which can turn the softmax (and the gradients) into
        # NaN. Leave such rows unmasked and blank their result afterwards.
        empty_rows = ~valid.any(dim=1)
        src_key_padding_mask = (~valid) & (~empty_rows).unsqueeze(1)

        output = self.encoder(item_emb, mask=causal_mask, src_key_padding_mask=src_key_padding_mask)

        lengths = valid.sum(dim=1).long()
        user_emb = output[torch.arange(B, device=device), (lengths - 1).clamp(min=0)]
        user_emb = user_emb.masked_fill(empty_rows.unsqueeze(1), 0.0)
        user_emb = self.output_norm(user_emb)

        return user_emb

    def forward(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
        """
        Training mode: scalar BCE loss over ``positive_ids`` (batch,) and
        ``negative_ids`` (batch, num_neg).
        Eval mode: (batch, embed_dim) user embeddings.

        ``batch`` must contain ``item_ids``; ``timestamps`` and ``mask`` are
        optional.
        """
        if self.training:
            item_ids = batch['item_ids']
            timestamps = batch.get('timestamps')
            mask = batch.get('mask')
            pos_ids = batch['positive_ids']
            neg_ids = batch['negative_ids']

            user_emb = self.encode(item_ids, timestamps, mask)
            pos_emb = self.item_embedding(pos_ids)
            neg_emb = self.item_embedding(neg_ids)

            pos_scores = (user_emb * pos_emb).sum(dim=-1)
            neg_scores = torch.einsum('bd,bnd->bn', user_emb, neg_emb)

            # Each term is averaged over its own elements before being summed.
            loss_pos = F.binary_cross_entropy_with_logits(pos_scores, torch.ones_like(pos_scores))
            loss_neg = F.binary_cross_entropy_with_logits(neg_scores, torch.zeros_like(neg_scores))

            return loss_pos + loss_neg
        else:
            return self.encode(batch['item_ids'], batch.get('timestamps'), batch.get('mask'))
|
|