ymlin105
feat(v2.5): ItemCF direction weight, Swing recall, LGBMRanker
fe617ac
import torch
import torch.nn as nn
import numpy as np
class SASRec(nn.Module):
    def __init__(self, num_items, max_len=50, hidden_dim=64, num_blocks=2, num_heads=2, dropout_rate=0.2):
        super(SASRec, self).__init__()
        self.num_items = num_items
        self.max_len = max_len
        self.hidden_dim = hidden_dim

        # Embeddings
        # Item embedding (index 0 is reserved for padding)
        self.item_emb = nn.Embedding(num_items + 1, hidden_dim, padding_idx=0)
        # Learnable positional embedding
        self.pos_emb = nn.Embedding(max_len, hidden_dim)
        self.emb_dropout = nn.Dropout(dropout_rate)

        # Transformer blocks: the standard PyTorch TransformerEncoder is the simplest fit here
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=hidden_dim,
            nhead=num_heads,
            dim_feedforward=hidden_dim * 4,
            dropout=dropout_rate,
            batch_first=True
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_blocks)
        self.last_layernorm = nn.LayerNorm(hidden_dim)
    def forward(self, input_seqs):
        # input_seqs: [batch_size, max_len]
        batch_size = input_seqs.shape[0]
        device = input_seqs.device
        sz = input_seqs.shape[1]

        # 1. Build masks
        # Padding mask: True where the token is padding (item id 0).
        # Kept for reference; it is intentionally NOT passed to the encoder (see step 3).
        padding_mask = (input_seqs == 0)  # [batch, len]
        # Causal mask: a position may not attend to future positions.
        # PyTorch's src_mask convention is -inf for masked, 0 for allowed, so we put
        # -inf on the strict upper triangle (diagonal=1) and leave the rest at 0.
        causal_mask = torch.triu(torch.ones(sz, sz, device=device) * float('-inf'), diagonal=1)

        # 2. Embedding
        seqs = self.item_emb(input_seqs)        # [batch, len, dim]
        seqs = seqs * (self.hidden_dim ** 0.5)  # scale by sqrt(dim), as in the SASRec paper
        # Standard SASRec uses the absolute input position for every slot, including padding;
        # padding positions are excluded later via the loss mask.
        positions = torch.arange(sz, device=device).unsqueeze(0).repeat(batch_size, 1)
        pos_emb = self.pos_emb(positions)
        x = seqs + pos_emb
        x = self.emb_dropout(x)

        # 3. Transformer
        # src_key_padding_mask is omitted on purpose: if every key of a row were masked,
        # attention would produce NaN. The trade-off is that the model may attend to
        # padding tokens, which is acceptable because the loss is masked on padding
        # positions (see the loss-masking sketch after the class).
        output = self.transformer(x, mask=causal_mask)
        output = self.last_layernorm(output)
        return output  # [batch, len, dim]
    def predict(self, input_seqs, candidate_items):
        # input_seqs: [batch, len]
        # candidate_items: [batch] (single item) or [batch, num_candidates].
        #   Not consumed here: for feature extraction we only return the user state;
        #   candidate scoring happens downstream (see the usage sketch after the class).
        # We only need the hidden state at the LAST position of the sequence.
        # This assumes sequences are LEFT-padded, so the last position always holds the
        # most recent real interaction; with right-padding you would have to index the
        # last non-zero position instead.
        seq_output = self.forward(input_seqs)
        final_state = seq_output[:, -1, :]  # [batch, dim]
        return final_state
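
A minimal usage sketch for feature extraction, assuming left-padded integer sequences and dot-product scoring against candidate item embeddings; the tensors and item ids below (history, candidates) are made up for illustration and are not part of the model above.

# Usage sketch (illustrative): extract the user state and score candidates by dot product.
import torch

model = SASRec(num_items=10_000, max_len=50)
model.eval()

# Two example user histories, left-padded with item id 0 (most recent interaction last).
history = torch.zeros(2, 50, dtype=torch.long)
history[0, -3:] = torch.tensor([12, 7, 345])
history[1, -2:] = torch.tensor([98, 501])

# Candidate items per user, e.g. produced by the ItemCF / Swing recall stage.
candidates = torch.tensor([[345, 99, 1024], [501, 7, 2]])

with torch.no_grad():
    user_state = model.predict(history, candidates)             # [batch, hidden_dim]
    cand_emb = model.item_emb(candidates)                       # [batch, n_cand, hidden_dim]
    scores = torch.einsum('bd,bnd->bn', user_state, cand_emb)   # [batch, n_cand]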
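
And a sketch of the loss masking that the forward comment relies on: next-item cross-entropy over all positions, with padding targets ignored via ignore_index=0. Projecting hidden states onto item logits through the shared item embedding table is an assumption here, not something defined in the module above.

# Loss-masking sketch (illustrative): padding targets (id 0) contribute nothing to the loss,
# which is why the encoder can be allowed to attend to padding without src_key_padding_mask.
import torch.nn.functional as F

def masked_next_item_loss(model, input_seqs, target_seqs):
    hidden = model(input_seqs)                    # [batch, len, dim]
    logits = hidden @ model.item_emb.weight.T     # [batch, len, num_items + 1]
    return F.cross_entropy(
        logits.reshape(-1, logits.size(-1)),
        target_seqs.reshape(-1),
        ignore_index=0,
    )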