MARS-SeqRec / model.py
CyberDancer's picture
MARS: Multi-scale Adaptive Recurrence with State compression
2319f81 verified
"""
MARS: Multi-scale Adaptive Recurrence with State compression
============================================================
An innovative method for super long sequence modeling in sequential recommendation.
Key innovations:

1. Temporal-Aware Delta Network (TADN) for O(n) long-range modeling
   - Explicit exponential temporal decay in state updates
   - Input-dependent gating for selective memory retention
2. Compressive Memory Tokens
   - Fixed-size learnable memory that compresses arbitrarily long histories
   - Acts as an information bottleneck (denoising effect, per Rec2PM)
3. Dual-Branch Architecture with Learned Fusion
   - Long-term branch: TADN layers processing the full history at O(n) cost
   - Short-term branch: standard self-attention on the most recent K interactions
   - Adaptive gating fusion that balances long-/short-term signals per user
4. Multi-Scale Temporal Encoding
   - Relative time deltas + periodic components of the absolute timestamp
   - Captures hourly/daily/weekly/monthly patterns in user behavior

This combines ideas from HyTRec (2602.18283), Rec2PM (2602.11605),
SIGMA (2408.11451), and HSTU (2402.17152) into a unified architecture.
"""
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Tuple, Dict
class TemporalEncoding(nn.Module):
    """Multi-scale temporal encoding with periodic components.

    Combines two signals derived from per-interaction timestamps:

    * the log-scaled gap to the previous interaction, and
    * sin/cos features of the timestamp at several fixed periods
      (hour, day, week, 30 days).

    Args:
        embed_dim: size of the returned embedding.
        max_periods: upper bound on how many of the built-in periods are
            used. Values above the number of built-in periods (4) simply use
            all of them.
    """

    def __init__(self, embed_dim: int, max_periods: int = 4):
        super().__init__()
        self.embed_dim = embed_dim

        # Relative time delta projection
        self.time_delta_proj = nn.Linear(1, embed_dim)

        # Periodic components (hourly=3600s, daily=86400s, weekly=604800s,
        # 30 days=2592000s).
        periods = [3600, 86400, 604800, 2592000][:max_periods]
        self.register_buffer('periods', torch.tensor(periods, dtype=torch.float32))

        # Size the projection from the periods actually kept, not from
        # ``max_periods``: the slice above can return fewer entries than
        # requested, which previously produced a shape mismatch in forward().
        self.periodic_proj = nn.Linear(len(periods) * 2, embed_dim)  # sin + cos

        # Normalises the sum of the two temporal signals.
        self.layernorm = nn.LayerNorm(embed_dim)

    def forward(self, timestamps: torch.Tensor) -> torch.Tensor:
        """
        Args:
            timestamps: (batch, seq_len) absolute timestamps in seconds.
                Integer tensors are accepted and preferred for epoch-scale
                values, since float32 cannot resolve them to the second.

        Returns:
            temporal_emb: (batch, seq_len, embed_dim)
        """
        # All learned projections run in the dtype of the module's weights.
        compute_dtype = self.time_delta_proj.weight.dtype

        # 1. Relative time deltas (seconds since previous interaction).
        #    Computed in the input dtype so integer inputs stay exact; the
        #    first position has no predecessor and keeps a delta of 0.
        time_deltas = torch.zeros_like(timestamps)
        time_deltas[:, 1:] = timestamps[:, 1:] - timestamps[:, :-1]
        time_deltas = time_deltas.clamp(min=0)

        # Log-scale for better numerical properties
        log_deltas = torch.log1p(time_deltas.to(compute_dtype)).unsqueeze(-1)  # (B, T, 1)
        delta_emb = self.time_delta_proj(log_deltas)  # (B, T, D)

        # 2. Periodic components.
        #    Reduce the timestamp modulo each period *before* converting to
        #    floating point. sin/cos are periodic, so the result is
        #    mathematically unchanged, but the angle is no longer derived
        #    from a huge quotient whose fractional part is lost to rounding.
        periods = self.periods.view(1, 1, -1)  # (1, 1, P)
        phase = torch.remainder(
            timestamps.unsqueeze(-1), periods.to(timestamps.dtype)
        )  # (B, T, P), each value in [0, period)
        angles = 2 * math.pi * phase.to(compute_dtype) / periods.to(compute_dtype)
        periodic_features = torch.cat(
            [torch.sin(angles), torch.cos(angles)], dim=-1
        )  # (B, T, 2*P)
        periodic_emb = self.periodic_proj(periodic_features)  # (B, T, D)

        # 3. Combine
        return self.layernorm(delta_emb + periodic_emb)
class TADNLayer(nn.Module):
    """Temporal-Aware Delta Network Layer.

    Recurrent layer whose cost is linear in sequence length, with:

    - Delta-rule style state updates (inspired by HyTRec)
    - Explicit temporal decay gating
    - Input-dependent selective memory

    A state matrix ``S`` of shape (state_dim, embed_dim) is carried along the
    sequence and updated element-wise at every valid step::

        erase_t = clamp((beta_t * k_t) outer g_t, 0, 1)
        S_t     = S_{t-1} * (1 - erase_t) + (beta_t * k_t) outer v_t

    where the gate mixes a dynamic, recency-weighted term with a learned
    static baseline::

        g_t   = a * sigmoid(W_g [x_t, x_t - x_{t-1}]) * tau_t
                + (1 - a) * sigmoid(g_static),      a = sigmoid(alpha)
        tau_t = recency weight in (0, 1]; 1 for the most recent valid item

    Args:
        embed_dim: feature size of the input and output.
        state_dim: number of rows of the recurrent state (query/key size).
        dropout: dropout applied to the read-out before the output projection.
    """

    def __init__(self, embed_dim: int, state_dim: int = 64, dropout: float = 0.1):
        super().__init__()
        self.embed_dim = embed_dim
        self.state_dim = state_dim

        # Query, Key, Value projections
        self.q_proj = nn.Linear(embed_dim, state_dim)
        self.k_proj = nn.Linear(embed_dim, state_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)

        # Gating mechanism
        self.gate_proj = nn.Linear(embed_dim * 2, embed_dim)
        self.beta_proj = nn.Linear(embed_dim, state_dim)

        # Temporal decay parameters
        self.alpha = nn.Parameter(torch.tensor(0.5))
        self.time_scale = nn.Parameter(torch.tensor(1.0))

        # Static gate (learnable baseline)
        self.gate_static = nn.Parameter(torch.ones(embed_dim) * 0.5)

        # Output
        self.out_proj = nn.Linear(embed_dim, embed_dim)
        self.layernorm = nn.LayerNorm(embed_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(
        self,
        x: torch.Tensor,
        timestamps: Optional[torch.Tensor] = None,
        mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """
        Args:
            x: (batch, seq_len, embed_dim) input sequence
            timestamps: (batch, seq_len) timestamps in seconds for temporal
                decay; when omitted, distance in positions is used instead
            mask: (batch, seq_len) mask (True / non-zero = valid)

        Returns:
            output: (batch, seq_len, embed_dim)
        """
        B, T, D = x.shape
        if T == 0:
            # Nothing to scan; an empty sequence maps to an empty sequence.
            return x

        # Project to Q, K, V
        q = self.q_proj(x)  # (B, T, state_dim)
        k = self.k_proj(x)  # (B, T, state_dim)
        v = self.v_proj(x)  # (B, T, D)

        # Beta (key importance scaling)
        beta = torch.sigmoid(self.beta_proj(x))  # (B, T, state_dim)

        # Index of the last VALID position per row. Padding positions carry
        # placeholder timestamps, so the reference "now" must come from here
        # rather than from the physical end of the tensor.
        positions = torch.arange(T, device=x.device)
        if mask is not None:
            valid = mask.to(torch.bool)
            # Largest index whose mask is True (0 when the row is all padding).
            last_idx = (positions.unsqueeze(0) * valid.long()).max(dim=1).values
        else:
            valid = None
            last_idx = torch.full((B,), T - 1, dtype=torch.long, device=x.device)

        # Temporal decay
        if timestamps is not None:
            t_last = timestamps.gather(1, last_idx.unsqueeze(1)).unsqueeze(-1)  # (B, 1, 1)
            t_behavior = timestamps.unsqueeze(-1)  # (B, T, 1)
            time_delta = (t_last - t_behavior).clamp(min=0)

            # Log-normalize: log(1 + delta_seconds / 3600) -> hours scale,
            # so ranges from seconds to years stay numerically tame.
            log_delta = torch.log1p(time_delta / 3600.0)

            # Learnable time scale controls the decay rate
            tau = torch.exp(
                -log_delta / (torch.abs(self.time_scale) * 10.0 + 1.0)
            )  # (B, T, 1), values in (0, 1]
            # Timestamps may arrive as float64; keep the gate in x's dtype.
            tau = tau.to(x.dtype)
        else:
            # Fallback: decay with distance (in positions) from the most
            # recent valid item, mirroring the timestamp branch where the
            # newest interaction has tau == 1.
            distance = (last_idx.unsqueeze(1) - positions.unsqueeze(0)).clamp(min=0)
            tau = torch.exp(-distance.to(x.dtype) / (T + 1e-6)).unsqueeze(-1)  # (B, T, 1)

        # Dynamic gating with temporal awareness.
        # The step-to-step change of the input acts as a change detector.
        x_shifted = torch.cat([torch.zeros_like(x[:, :1]), x[:, :-1]], dim=1)
        delta_x = x - x_shifted
        gate_input = torch.cat([x, delta_x], dim=-1)  # (B, T, 2*D)

        alpha = torch.sigmoid(self.alpha)
        g_dynamic = torch.sigmoid(self.gate_proj(gate_input))  # (B, T, D)
        g = alpha * g_dynamic * tau + (1 - alpha) * torch.sigmoid(self.gate_static)

        # Recurrent state update.
        scaled_k = beta * k  # (B, T, state_dim); loop-invariant product
        valid_f = None if valid is None else valid.to(x.dtype)
        S = torch.zeros(B, self.state_dim, D, device=x.device, dtype=x.dtype)
        outputs = []
        for t in range(T):
            bk_t = scaled_k[:, t]  # (B, state_dim)

            # Erase old content, then write new; erase is clamped to [0, 1]
            # so the multiplicative factor never flips sign.
            erase = torch.einsum('bs,bd->bsd', bk_t, g[:, t]).clamp(0, 1)
            write = torch.einsum('bs,bd->bsd', bk_t, v[:, t])

            if valid_f is not None:
                # Padding steps leave the state untouched.
                step_valid = valid_f[:, t].view(B, 1, 1)
                S = S * (1 - erase * step_valid) + write * step_valid
            else:
                S = S * (1 - erase) + write

            # Clamp state for numerical stability
            S = S.clamp(-10, 10)

            # Read from state
            outputs.append(torch.einsum('bs,bsd->bd', q[:, t], S))

        output = torch.stack(outputs, dim=1)  # (B, T, D)
        output = self.out_proj(self.dropout(output))
        output = self.layernorm(x + output)  # Residual connection
        return output
class CompressiveMemory(nn.Module):
    """Compressive Memory Module.

    Compresses a sequence into a fixed number of memory tokens: a set of
    learnable query vectors cross-attends to the sequence, followed by a
    residual feed-forward block. The fixed token count acts as an information
    bottleneck (denoising, per Rec2PM theory).

    Args:
        embed_dim: feature size of the sequence and of each memory token.
        num_memory_tokens: number of memory tokens produced per sequence.
        num_heads: attention heads of the cross-attention.
        dropout: dropout rate used in the attention and the feed-forward
            block (defaults to the previously hard-coded 0.1).
    """

    def __init__(
        self,
        embed_dim: int,
        num_memory_tokens: int = 8,
        num_heads: int = 2,
        dropout: float = 0.1,
    ):
        super().__init__()
        self.num_memory_tokens = num_memory_tokens

        # Learnable memory query tokens
        self.memory_queries = nn.Parameter(
            torch.randn(num_memory_tokens, embed_dim) * 0.02
        )

        # Cross-attention: memory queries attend to sequence
        self.cross_attn = nn.MultiheadAttention(
            embed_dim=embed_dim,
            num_heads=num_heads,
            batch_first=True,
            dropout=dropout
        )
        self.ffn = nn.Sequential(
            nn.Linear(embed_dim, embed_dim * 4),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(embed_dim * 4, embed_dim),
            nn.Dropout(dropout),
        )
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)

    def forward(
        self,
        sequence: torch.Tensor,
        mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """
        Args:
            sequence: (batch, seq_len, embed_dim) - encoded sequence
            mask: (batch, seq_len) mask (True / non-zero = valid, False = padding)

        Returns:
            memory: (batch, num_memory_tokens, embed_dim). For a row with no
            valid position the attention contribution is zero, so its memory
            is derived from the learned queries alone.
        """
        B = sequence.shape[0]

        # Expand memory queries for batch
        queries = self.memory_queries.unsqueeze(0).expand(B, -1, -1)  # (B, M, D)

        # nn.MultiheadAttention expects key_padding_mask where True = IGNORE
        key_padding_mask = None
        empty_rows = None
        if mask is not None:
            key_padding_mask = ~mask.to(torch.bool)
            # A row whose keys are ALL ignored makes the attention softmax
            # return NaN. Un-ignore the keys of such rows so the attention
            # stays finite, then discard its output for them below.
            empty_rows = key_padding_mask.all(dim=1)  # (B,)
            key_padding_mask = key_padding_mask & ~empty_rows.unsqueeze(1)

        attn_out, _ = self.cross_attn(
            query=queries,
            key=sequence,
            value=sequence,
            key_padding_mask=key_padding_mask
        )
        if empty_rows is not None:
            attn_out = attn_out.masked_fill(empty_rows.view(B, 1, 1), 0.0)

        memory = self.norm1(queries + attn_out)
        memory = self.norm2(memory + self.ffn(memory))
        return memory
class ShortTermAttention(nn.Module):
    """Standard self-attention block for short-term (recent) interactions.

    Uses causal multi-head attention (a pre-norm ``nn.TransformerEncoder``)
    over the most recent K items, where O(K^2) cost is acceptable.

    Args:
        embed_dim: feature size of the input and output.
        num_heads: attention heads per layer.
        num_layers: number of encoder layers.
        dropout: dropout rate inside each encoder layer.
    """

    def __init__(self, embed_dim: int, num_heads: int = 2, num_layers: int = 2, dropout: float = 0.1):
        super().__init__()
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=embed_dim,
            nhead=num_heads,
            dim_feedforward=embed_dim * 4,
            dropout=dropout,
            activation='gelu',
            batch_first=True,
            norm_first=True
        )
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

    def forward(
        self,
        x: torch.Tensor,
        mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """
        Args:
            x: (batch, K, embed_dim) recent interactions
            mask: (batch, K) mask (True / non-zero = valid)

        Returns:
            output: (batch, K, embed_dim). Rows without any valid position
            are returned as zeros.

        NOTE: combined with the causal mask, a padded position that precedes
        every valid one has no key it may attend to, so padding is expected
        at the end of each row (right padding).
        """
        T = x.shape[1]

        # Causal mask: True above the diagonal = may not attend to the future
        causal_mask = torch.triu(
            torch.ones(T, T, device=x.device, dtype=torch.bool), diagonal=1
        )

        # Padding mask (True = ignore this key)
        src_key_padding_mask = None
        empty_rows = None
        if mask is not None:
            padding = ~mask.to(torch.bool)
            # A row in which every key is ignored yields NaN from the
            # attention softmax. Let such rows attend normally and blank
            # their output afterwards instead.
            empty_rows = padding.all(dim=1)  # (B,)
            src_key_padding_mask = padding & ~empty_rows.unsqueeze(1)

        output = self.encoder(
            x,
            mask=causal_mask,
            src_key_padding_mask=src_key_padding_mask
        )
        if empty_rows is not None:
            output = output.masked_fill(empty_rows.view(-1, 1, 1), 0.0)
        return output
class AdaptiveFusionGate(nn.Module):
    """Blend long-term and short-term user representations with a learned gate.

    A small MLP looks at the long-term vector, the short-term vector and the
    memory summary, and emits one weight in (0, 1) per feature::

        fused = w * long_term + (1 - w) * short_term

    The memory summary only informs the weight; it is not itself mixed into
    the result.
    """

    def __init__(self, embed_dim: int):
        super().__init__()
        gate_layers = [
            nn.Linear(3 * embed_dim, embed_dim),
            nn.GELU(),
            nn.Linear(embed_dim, embed_dim),
            nn.Sigmoid(),
        ]
        self.gate = nn.Sequential(*gate_layers)

    def forward(
        self,
        long_term: torch.Tensor,
        short_term: torch.Tensor,
        memory: torch.Tensor
    ) -> torch.Tensor:
        """
        Args:
            long_term: (batch, embed_dim)
            short_term: (batch, embed_dim)
            memory: (batch, embed_dim) compressed memory summary

        Returns:
            fused: (batch, embed_dim)
        """
        features = torch.cat((long_term, short_term, memory), dim=-1)
        weight = self.gate(features)
        long_share = weight * long_term
        short_share = (1 - weight) * short_term
        return long_share + short_share
class MARS(nn.Module):
    """
    MARS: Multi-scale Adaptive Recurrence with State compression

    Architecture:
        Input: Full user interaction sequence + timestamps
                          |
                          v
            [Item Embedding + Temporal Encoding]
                          |
             +---- Long-term Branch (TADN layers, O(n))
             |            |
             |     [Compressive Memory] -> memory tokens
             |
             +---- Short-term Branch (Self-Attention on recent K items)
                          |
                          v
                [Adaptive Fusion Gate]
                          |
                          v
                [Output projection] -> user embedding, scored against
                                       item embeddings by dot product

    Args:
        num_items: number of unique items
        embed_dim: embedding dimension
        max_seq_len: maximum sequence length (can be very long, e.g. 2048);
            positions beyond it share the last position embedding
        short_term_len: number of recent items for short-term branch
        num_memory_tokens: number of compressive memory tokens
        num_tadn_layers: number of TADN layers in long-term branch
        num_attn_layers: number of attention layers in short-term branch
        num_heads: number of attention heads
        state_dim: state dimension for TADN
        dropout: dropout rate
    """

    def __init__(
        self,
        num_items: int,
        embed_dim: int = 64,
        max_seq_len: int = 512,
        short_term_len: int = 50,
        num_memory_tokens: int = 8,
        num_tadn_layers: int = 3,
        num_attn_layers: int = 2,
        num_heads: int = 2,
        state_dim: int = 64,
        dropout: float = 0.1,
    ):
        super().__init__()
        self.num_items = num_items
        self.embed_dim = embed_dim
        self.max_seq_len = max_seq_len
        self.short_term_len = short_term_len
        self.num_memory_tokens = num_memory_tokens

        # Item embeddings (0 = padding)
        self.item_embedding = nn.Embedding(num_items + 1, embed_dim, padding_idx=0)

        # Temporal encoding
        self.temporal_encoding = TemporalEncoding(embed_dim)

        # Learnable position encoding (shared by both branches)
        self.position_embedding = nn.Embedding(max_seq_len, embed_dim)

        # Input processing
        self.input_norm = nn.LayerNorm(embed_dim)
        self.input_dropout = nn.Dropout(dropout)

        # Long-term branch: stack of TADN layers
        self.tadn_layers = nn.ModuleList([
            TADNLayer(embed_dim, state_dim, dropout)
            for _ in range(num_tadn_layers)
        ])

        # Compressive memory
        self.compressive_memory = CompressiveMemory(
            embed_dim, num_memory_tokens, num_heads
        )

        # Short-term branch: standard self-attention
        self.short_term_attn = ShortTermAttention(
            embed_dim, num_heads, num_attn_layers, dropout
        )

        # Adaptive fusion
        self.fusion_gate = AdaptiveFusionGate(embed_dim)

        # Output projection
        self.output_norm = nn.LayerNorm(embed_dim)
        self.output_proj = nn.Linear(embed_dim, embed_dim)

        # Initialize weights
        self._init_weights()

    def _init_weights(self):
        """Initialize matrices with a truncated normal and biases with zeros.

        Only parameters whose name contains ``weight`` (with 2+ dimensions)
        or ``bias`` are touched; 1-D weights and the remaining parameters
        keep the initial values given by their constructors.
        """
        for name, param in self.named_parameters():
            if 'weight' in name and param.dim() >= 2:
                nn.init.trunc_normal_(param, std=0.02)
            elif 'bias' in name:
                nn.init.zeros_(param)

        # Special init for item embeddings
        nn.init.trunc_normal_(self.item_embedding.weight, std=0.02)
        with torch.no_grad():
            self.item_embedding.weight[0].zero_()  # Padding = zero

    @property
    def item_embeddings(self):
        """Access item embedding table (for evaluation)."""
        return self.item_embedding

    @staticmethod
    def _recent_window(
        item_ids: torch.Tensor,
        timestamps: Optional[torch.Tensor],
        lengths: torch.Tensor,
        window: int,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], torch.Tensor]:
        """Gather the last ``window`` valid entries of every row.

        Rows are assumed to hold their ``lengths[b]`` valid entries at
        positions ``0 .. lengths[b] - 1``. Rows shorter than ``window`` are
        right-padded with zeros. ``window`` must not exceed the sequence
        length of ``item_ids``.

        Args:
            item_ids: (batch, seq_len) item indices
            timestamps: (batch, seq_len) timestamps, or None
            lengths: (batch,) integer count of valid entries per row
            window: number of columns to return

        Returns:
            (recent_ids, recent_timestamps, recent_mask), each (batch, window);
            ``recent_timestamps`` is None when ``timestamps`` is None.
        """
        offsets = torch.arange(window, device=item_ids.device).unsqueeze(0)  # (1, K)
        start = (lengths - window).clamp(min=0).unsqueeze(1)  # (B, 1)
        index = start + offsets  # (B, K)
        recent_mask = offsets < lengths.clamp(max=window).unsqueeze(1)  # (B, K)

        recent_ids = item_ids.gather(1, index).masked_fill(~recent_mask, 0)
        recent_ts = None
        if timestamps is not None:
            recent_ts = timestamps.gather(1, index).masked_fill(~recent_mask, 0)
        return recent_ids, recent_ts, recent_mask

    def encode(
        self,
        item_ids: torch.Tensor,
        timestamps: Optional[torch.Tensor] = None,
        mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """
        Encode a full sequence into user representations.

        Args:
            item_ids: (batch, seq_len) item indices (0 = padding),
                right-padded
            timestamps: (batch, seq_len) timestamps in seconds
            mask: (batch, seq_len) mask (True / non-zero = valid); derived
                from ``item_ids != 0`` when omitted

        Returns:
            user_emb: (batch, embed_dim) final user representation
        """
        B, T = item_ids.shape
        device = item_ids.device

        # Create mask from padding if not provided; otherwise make sure it is
        # boolean, because it is inverted with ``~`` further down.
        if mask is None:
            mask = (item_ids != 0)
        else:
            mask = mask.to(torch.bool)

        # Timestamps as fed to the temporal encoding. Integer timestamps are
        # passed through untouched so that consecutive differences are taken
        # exactly; float32 cannot resolve epoch-scale seconds finer than
        # roughly two minutes. Floating inputs are cast to float32 as before.
        ts_features = None
        if timestamps is not None:
            ts_features = timestamps.float() if timestamps.is_floating_point() else timestamps

        # 1. Item + Temporal Embeddings
        item_emb = self.item_embedding(item_ids)  # (B, T, D)
        if ts_features is not None:
            item_emb = item_emb + self.temporal_encoding(ts_features)

        # Add position embeddings (only for the sequence order)
        positions = torch.arange(T, device=device).unsqueeze(0)
        positions = positions.clamp(max=self.max_seq_len - 1)
        item_emb = self.input_norm(item_emb + self.position_embedding(positions))
        item_emb = self.input_dropout(item_emb)

        # 2. Long-term Branch: TADN over full sequence
        long_term_repr = item_emb
        for tadn in self.tadn_layers:
            long_term_repr = tadn(long_term_repr, timestamps, mask)

        # Compress long-term into memory tokens
        memory = self.compressive_memory(long_term_repr, mask)  # (B, M, D)
        memory_summary = memory.mean(dim=1)  # (B, D) - aggregated memory

        # Last valid long-term representation (index 0 for empty rows)
        lengths = mask.sum(dim=1).long()  # (B,)
        batch_index = torch.arange(B, device=device)
        long_term_last = long_term_repr[batch_index, (lengths - 1).clamp(min=0)]  # (B, D)

        # 3. Short-term Branch: Attention on recent K items.
        # With right-padding, valid items are at positions 0...(length-1);
        # the window is gathered for the whole batch at once, avoiding a
        # per-sample Python loop and its host synchronisations.
        K = min(self.short_term_len, T)
        short_item_ids, short_ts, short_mask = self._recent_window(
            item_ids, ts_features, lengths, K
        )

        short_emb = self.item_embedding(short_item_ids)
        if short_ts is not None:
            short_emb = short_emb + self.temporal_encoding(short_ts)
        short_positions = torch.arange(K, device=device).unsqueeze(0)
        short_positions = short_positions.clamp(max=self.max_seq_len - 1)
        short_emb = short_emb + self.position_embedding(short_positions)
        short_emb = self.input_norm(short_emb)
        short_term_repr = self.short_term_attn(short_emb, short_mask)

        # Last valid short-term representation
        short_lengths = short_mask.sum(dim=1).long()
        short_term_last = short_term_repr[
            batch_index, (short_lengths - 1).clamp(min=0)
        ]  # (B, D)

        # 4. Adaptive Fusion
        user_emb = self.fusion_gate(long_term_last, short_term_last, memory_summary)
        user_emb = self.output_proj(self.output_norm(user_emb))
        return user_emb

    def forward(
        self,
        batch: Dict[str, torch.Tensor]
    ) -> torch.Tensor:
        """
        Forward pass; behaviour depends on ``self.training``.

        Expected batch format (flat tensors, matching Yambda convention):
            - item_ids: (batch, max_seq_len) padded item sequences
            - timestamps: (batch, max_seq_len) padded timestamps (optional)
            - mask: (batch, max_seq_len) boolean mask (optional)
            - positive_ids: (batch,) positive next items (training only)
            - negative_ids: (batch, num_neg) negative items (training only)

        Returns:
            In training mode: scalar BCE loss.
            In eval mode: (batch, embed_dim) user embeddings.
        """
        if self.training:
            return self._training_forward(batch)
        return self._eval_forward(batch)

    def _training_forward(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
        """Compute training loss with next-item prediction."""
        item_ids = batch['item_ids']  # (B, T)
        timestamps = batch.get('timestamps')  # (B, T) or None
        mask = batch.get('mask')  # (B, T) or None
        pos_ids = batch['positive_ids']  # (B,)
        neg_ids = batch['negative_ids']  # (B, num_neg)

        # Encode user sequence
        user_emb = self.encode(item_ids, timestamps, mask)  # (B, D)

        # Score positive and negative items by dot product
        pos_emb = self.item_embedding(pos_ids)  # (B, D)
        neg_emb = self.item_embedding(neg_ids)  # (B, num_neg, D)
        pos_scores = (user_emb * pos_emb).sum(dim=-1)  # (B,)
        neg_scores = torch.einsum('bd,bnd->bn', user_emb, neg_emb)  # (B, num_neg)

        # Binary cross-entropy: positives are labelled 1, negatives 0. Each
        # term is averaged on its own, so positives and negatives weigh the
        # same regardless of num_neg.
        pos_labels = torch.ones_like(pos_scores)
        neg_labels = torch.zeros_like(neg_scores)
        loss_pos = F.binary_cross_entropy_with_logits(pos_scores, pos_labels)
        loss_neg = F.binary_cross_entropy_with_logits(neg_scores, neg_labels)
        return loss_pos + loss_neg

    def _eval_forward(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
        """Eval forward: returns user embeddings."""
        return self.encode(
            batch['item_ids'], batch.get('timestamps'), batch.get('mask')
        )
class SASRecBaseline(nn.Module):
    """
    Standard SASRec baseline for comparison.

    Uses causal self-attention (O(n^2) complexity) over item + position
    embeddings and takes the encoder output at the last valid position as
    the user embedding.

    Args:
        num_items: number of unique items (index 0 is reserved for padding)
        embed_dim: embedding dimension
        max_seq_len: number of position embeddings; longer inputs share the
            last one
        num_heads: attention heads per layer
        num_layers: number of encoder layers
        dropout: dropout rate
    """

    def __init__(
        self,
        num_items: int,
        embed_dim: int = 64,
        max_seq_len: int = 200,
        num_heads: int = 2,
        num_layers: int = 2,
        dropout: float = 0.1,
    ):
        super().__init__()
        self.num_items = num_items
        self.embed_dim = embed_dim
        self.max_seq_len = max_seq_len

        self.item_embedding = nn.Embedding(num_items + 1, embed_dim, padding_idx=0)
        self.position_embedding = nn.Embedding(max_seq_len, embed_dim)
        self.input_norm = nn.LayerNorm(embed_dim)
        self.input_dropout = nn.Dropout(dropout)

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=embed_dim,
            nhead=num_heads,
            dim_feedforward=embed_dim * 4,
            dropout=dropout,
            activation='gelu',
            batch_first=True,
            norm_first=True
        )
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.output_norm = nn.LayerNorm(embed_dim)

        self._init_weights()

    def _init_weights(self):
        """Truncated-normal init for matrices, zeros for biases and padding row."""
        for name, param in self.named_parameters():
            if 'weight' in name and param.dim() >= 2:
                nn.init.trunc_normal_(param, std=0.02)
            elif 'bias' in name:
                nn.init.zeros_(param)
        with torch.no_grad():
            self.item_embedding.weight[0].zero_()  # Padding = zero

    @property
    def item_embeddings(self):
        """Access item embedding table (for evaluation)."""
        return self.item_embedding

    def encode(
        self,
        item_ids: torch.Tensor,
        timestamps: Optional[torch.Tensor] = None,
        mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Encode padded item sequences into user embeddings.

        Args:
            item_ids: (batch, seq_len) item indices (0 = padding)
            timestamps: accepted for interface parity with MARS; unused
            mask: (batch, seq_len) mask (True / non-zero = valid); derived
                from ``item_ids != 0`` when omitted

        Returns:
            user_emb: (batch, embed_dim)
        """
        B, T = item_ids.shape
        device = item_ids.device

        # ``~`` on an integer mask is a bitwise NOT, not a logical one, so
        # the mask is coerced to bool before it is inverted.
        if mask is None:
            mask = (item_ids != 0)
        else:
            mask = mask.to(torch.bool)

        item_emb = self.item_embedding(item_ids)
        positions = torch.arange(T, device=device).unsqueeze(0)
        positions = positions.clamp(max=self.max_seq_len - 1)
        item_emb = item_emb + self.position_embedding(positions)
        item_emb = self.input_norm(item_emb)
        item_emb = self.input_dropout(item_emb)

        causal_mask = torch.triu(
            torch.ones(T, T, device=device, dtype=torch.bool), diagonal=1
        )

        # Rows with no valid item would have every key ignored, which makes
        # the attention softmax return NaN. Let them attend normally and
        # blank their embedding below.
        padding = ~mask
        empty_rows = padding.all(dim=1)  # (B,)
        src_key_padding_mask = padding & ~empty_rows.unsqueeze(1)

        output = self.encoder(
            item_emb, mask=causal_mask, src_key_padding_mask=src_key_padding_mask
        )

        lengths = mask.sum(dim=1).long()
        user_emb = output[torch.arange(B, device=device), (lengths - 1).clamp(min=0)]
        user_emb = user_emb.masked_fill(empty_rows.unsqueeze(1), 0.0)
        user_emb = self.output_norm(user_emb)
        return user_emb

    def forward(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
        """Return the BCE loss in training mode, user embeddings in eval mode.

        Training batches need ``item_ids``, ``positive_ids`` (batch,) and
        ``negative_ids`` (batch, num_neg); ``timestamps`` and ``mask`` are
        optional.
        """
        user_emb = self.encode(
            batch['item_ids'], batch.get('timestamps'), batch.get('mask')
        )
        if not self.training:
            return user_emb

        pos_emb = self.item_embedding(batch['positive_ids'])
        neg_emb = self.item_embedding(batch['negative_ids'])
        pos_scores = (user_emb * pos_emb).sum(dim=-1)
        neg_scores = torch.einsum('bd,bnd->bn', user_emb, neg_emb)
        loss_pos = F.binary_cross_entropy_with_logits(pos_scores, torch.ones_like(pos_scores))
        loss_neg = F.binary_cross_entropy_with_logits(neg_scores, torch.zeros_like(neg_scores))
        return loss_pos + loss_neg