"""
Evaluation module for sequential recommendation models.
Metrics:
- HR@K (Hit Rate): Whether the ground truth item appears in top-K
- NDCG@K (Normalized Discounted Cumulative Gain): Position-aware ranking quality
- MRR@K (Mean Reciprocal Rank): Reciprocal of the rank of the first correct item
"""
import time
from typing import Dict, List, Sequence

import numpy as np
import torch
def _eval_rank_of_target(scores: torch.Tensor, target_col: torch.Tensor) -> torch.Tensor:
    """Return the 1-based rank of each row's target column in ``scores``.

    Args:
        scores: (B, C) candidate scores, one row per query.
        target_col: (B,) integer tensor giving the target's column in each row.

    Returns:
        (B,) long tensor equal to 1 + the number of candidates that score
        strictly higher than the target, so ties favour the target.
    """
    # gather() requires an int64 index.
    target_scores = scores.gather(1, target_col.long().unsqueeze(1))  # (B, 1)
    return (scores > target_scores).sum(dim=1) + 1


def _eval_metrics_from_ranks(ranks: np.ndarray, ks: Sequence[int]) -> Dict[str, float]:
    """Average HR@K, NDCG@K and MRR@K over 1-based ``ranks``.

    With a single relevant item per query, NDCG@K is 1 / log2(rank + 1) and
    MRR@K is 1 / rank when rank <= K, and both are 0 otherwise. An empty
    ``ranks`` yields NaN for every metric.
    """
    ranks = np.asarray(ranks, dtype=np.float64)
    metrics = {}
    for k in ks:
        if ranks.size == 0:
            # Avoid np.mean's "mean of empty slice" warning; the result is NaN either way.
            hr = ndcg = mrr = np.float64(np.nan)
        else:
            in_top_k = ranks <= k
            hr = np.mean(in_top_k)
            ndcg = np.mean(np.where(in_top_k, 1.0 / np.log2(ranks + 1.0), 0.0))
            mrr = np.mean(np.where(in_top_k, 1.0 / ranks, 0.0))
        metrics[f'HR@{k}'] = hr
        metrics[f'NDCG@{k}'] = ndcg
        metrics[f'MRR@{k}'] = mrr
    return metrics


@torch.no_grad()
def evaluate_model(
    model,
    eval_loader,
    num_items: int,
    device: torch.device,
    ks: Sequence[int] = (5, 10, 20, 50),
    full_ranking: bool = False,
) -> Dict[str, float]:
    """
    Evaluate a sequential recommendation model.

    Uses sampled metrics by default: each positive is ranked against the
    negatives supplied in its batch. Set full_ranking=True to rank the
    positive against every row of the item embedding table except row 0
    (padding). This is slower, but it avoids the bias of negative sampling.

    Args:
        model: trained model. ``model(batch)`` must return user embeddings
            of shape (B, D). ``model.item_embeddings`` is called with item
            ids and its ``.weight`` is read; row 0 is treated as padding.
        eval_loader: iterable of dicts of tensors. Each dict needs
            'positive_ids' (B,), holding 1-based item ids. In sampled mode
            it also needs 'negative_ids' (B, num_neg). Every value is moved
            to ``device`` and the whole dict is passed to ``model``.
        num_items: total number of items (not used by the computation; kept
            for API compatibility).
        device: torch device for the batch tensors.
        ks: cut-offs K for the metrics.
        full_ranking: if True, rank against all items.

    Returns:
        dict with HR@K, NDCG@K and MRR@K for every K in ``ks`` (NaN if the
        loader yielded nothing), plus 'eval_time' in seconds.

    Raises:
        ValueError: in full-ranking mode, if a positive id is 0 (padding) or
            exceeds the number of non-padding item rows.
    """
    model.eval()
    start_time = time.time()
    # The item table does not change during evaluation, so slice off padding row 0 once.
    all_item_embs = model.item_embeddings.weight[1:] if full_ranking else None
    rank_batches = []
    for batch in eval_loader:
        batch_device = {k: v.to(device) for k, v in batch.items()}
        user_emb = model(batch_device)  # (B, D)
        pos_ids = batch_device['positive_ids']  # (B,)
        if full_ranking:
            num_candidates = all_item_embs.shape[0]
            # Without this check, id 0 becomes column -1 and silently scores the last item.
            if (pos_ids < 1).any() or (pos_ids > num_candidates).any():
                raise ValueError(
                    f'positive_ids must be in [1, {num_candidates}] for full ranking'
                )
            scores = torch.matmul(user_emb, all_item_embs.t())  # (B, num_candidates)
            # Item id i (1-based) lives in column i - 1 once padding is dropped.
            ranks = _eval_rank_of_target(scores, pos_ids - 1)
        else:
            # Only sampled ranking needs negatives; full-ranking loaders may omit them.
            neg_ids = batch_device['negative_ids']  # (B, num_neg)
            pos_emb = model.item_embeddings(pos_ids)  # (B, D)
            neg_emb = model.item_embeddings(neg_ids)  # (B, num_neg, D)
            pos_scores = (user_emb * pos_emb).sum(dim=-1, keepdim=True)  # (B, 1)
            neg_scores = torch.einsum('bd,bnd->bn', user_emb, neg_emb)  # (B, num_neg)
            # Column 0 holds the positive, followed by the negatives.
            all_scores = torch.cat([pos_scores, neg_scores], dim=1)  # (B, 1+num_neg)
            first_col = torch.zeros(
                all_scores.shape[0], dtype=torch.long, device=all_scores.device
            )
            ranks = _eval_rank_of_target(all_scores, first_col)
        rank_batches.append(ranks.cpu().numpy())
    eval_time = time.time() - start_time
    all_ranks = np.concatenate(rank_batches) if rank_batches else np.empty(0)
    metrics = _eval_metrics_from_ranks(all_ranks, ks)
    metrics['eval_time'] = eval_time
    return metrics
def _full_left_pad(batch_data, max_seq_len: int):
    """Left-pad one slice of ``eval_data`` to a common length.

    Each history is truncated to its most recent ``max_len`` items, where
    max_len = min(longest history in the slice, max_seq_len). Padding id 0
    goes on the left.

    Returns:
        (item_id_rows, mask_rows, next_items) as plain Python lists.
    """
    max_len = min(max(len(d['item_ids']) for d in batch_data), max_seq_len)
    item_rows, mask_rows, next_items = [], [], []
    for d in batch_data:
        # list() accepts tuples/arrays. The explicit max_len == 0 branch is
        # needed because a [-0:] slice would keep the whole history.
        ids = list(d['item_ids'])[-max_len:] if max_len > 0 else []
        pad_len = max_len - len(ids)
        item_rows.append([0] * pad_len + ids)
        mask_rows.append([False] * pad_len + [True] * len(ids))
        next_items.append(d['next_item'])
    return item_rows, mask_rows, next_items


@torch.no_grad()
def compute_metrics_full(
    model,
    eval_data,
    num_items: int,
    device: torch.device,
    max_seq_len: int = 512,
    ks: Sequence[int] = (5, 10, 20, 50),
    batch_size: int = 256,
) -> Dict[str, float]:
    """
    Full-ranking evaluation (all items).

    More accurate than sampled evaluation but slower. Each target is ranked
    against every row of ``model.item_embeddings.weight`` except row 0
    (padding).

    Args:
        model: model called with a dict holding 'item_ids' (B, L) long,
            left-padded with 0, and 'mask' (B, L) bool, True at real
            positions. It must return user embeddings of shape (B, D).
        eval_data: sliceable sequence of dicts, each with 'item_ids' (the
            history as 1-based item ids) and 'next_item' (the 1-based target id).
        num_items: total number of items (not used by the computation; kept
            for API compatibility).
        device: torch device for the batch tensors and item table.
        max_seq_len: maximum number of most-recent history items kept.
        ks: cut-offs K for the metrics.
        batch_size: number of examples scored per forward pass.

    Returns:
        dict with HR@K, NDCG@K and MRR@K for every K in ``ks`` (NaN if
        ``eval_data`` is empty).

    Raises:
        ValueError: if a 'next_item' is 0 (padding) or exceeds the number of
            non-padding item rows.
    """
    model.eval()
    # (num_candidates, D); row 0 of the table is padding and is excluded.
    all_item_embs = model.item_embeddings.weight[1:].to(device)
    num_candidates = all_item_embs.shape[0]
    rank_batches = []
    for start in range(0, len(eval_data), batch_size):
        item_rows, mask_rows, next_items = _full_left_pad(
            eval_data[start:start + batch_size], max_seq_len
        )
        batch = {
            'item_ids': torch.tensor(item_rows, dtype=torch.long, device=device),
            'mask': torch.tensor(mask_rows, dtype=torch.bool, device=device),
        }
        user_emb = model(batch)  # (B, D)
        gt = torch.tensor(next_items, dtype=torch.long, device=device)
        # Without this check, id 0 becomes column -1 and silently scores the last item.
        if (gt < 1).any() or (gt > num_candidates).any():
            raise ValueError(f'next_item must be in [1, {num_candidates}]')
        scores = torch.matmul(user_emb, all_item_embs.t())  # (B, num_candidates)
        target_scores = scores.gather(1, (gt - 1).unsqueeze(1))  # (B, 1)
        # 1-based rank; ties are resolved in the target's favour.
        ranks = (scores > target_scores).sum(dim=1) + 1
        rank_batches.append(ranks.cpu().numpy())

    all_ranks = np.concatenate(rank_batches) if rank_batches else np.empty(0)
    all_ranks = all_ranks.astype(np.float64)
    metrics = {}
    for k in ks:
        if all_ranks.size == 0:
            hr = ndcg = mrr = np.float64(np.nan)
        else:
            in_top_k = all_ranks <= k
            hr = np.mean(in_top_k)
            ndcg = np.mean(np.where(in_top_k, 1.0 / np.log2(all_ranks + 1.0), 0.0))
            mrr = np.mean(np.where(in_top_k, 1.0 / all_ranks, 0.0))
        metrics[f'HR@{k}'] = hr
        metrics[f'NDCG@{k}'] = ndcg
        # Provided so results line up with evaluate_model / print_comparison.
        metrics[f'MRR@{k}'] = mrr
    return metrics
def print_comparison(mars_results: Dict, sasrec_results: Dict, ks: Sequence[int] = (5, 10, 20)):
    """Pretty-print a per-metric comparison between MARS and SASRec.

    For every K in ``ks``, prints one row each for HR@K, NDCG@K and MRR@K.
    Each row shows both models' values, the absolute difference
    (MARS - SASRec), and the relative difference as a percentage of the
    SASRec value with a direction marker (↑ / ↓ / =).

    Args:
        mars_results: metric name -> value for MARS. Missing metrics print as 0.
        sasrec_results: metric name -> value for SASRec. Missing metrics print as 0.
        ks: cut-offs K to report.
    """
    print(f"\n{'='*70}")
    print(f"{'Metric':<15} | {'MARS':>10} | {'SASRec':>10} | {'Δ':>10} | {'Δ%':>10}")
    print(f"{'-'*70}")
    for k in ks:
        for metric_name in [f'HR@{k}', f'NDCG@{k}', f'MRR@{k}']:
            mars_val = mars_results.get(metric_name, 0)
            sasrec_val = sasrec_results.get(metric_name, 0)
            delta = mars_val - sasrec_val
            # The relative change is undefined for a non-positive baseline; report 0%.
            delta_pct = (delta / sasrec_val * 100) if sasrec_val > 0 else 0
            marker = '↑' if delta > 0 else '↓' if delta < 0 else '='
            print(f"{metric_name:<15} | {mars_val:>10.4f} | {sasrec_val:>10.4f} | "
                  f"{delta:>+10.4f} | {marker} {abs(delta_pct):>7.2f}%")
    print(f"{'='*70}")