""" Evaluation module for sequential recommendation models. Metrics: - HR@K (Hit Rate): Whether the ground truth item appears in top-K - NDCG@K (Normalized Discounted Cumulative Gain): Position-aware ranking quality - MRR@K (Mean Reciprocal Rank): Reciprocal of the rank of the first correct item """ import torch import numpy as np from typing import Dict, List import time @torch.no_grad() def evaluate_model( model, eval_loader, num_items: int, device: torch.device, ks: List[int] = [5, 10, 20, 50], full_ranking: bool = False, ) -> Dict[str, float]: """ Evaluate a sequential recommendation model. Uses sampled metrics by default (positive + negatives from batch). Set full_ranking=True for ranking against all items (slow but accurate). Args: model: trained model eval_loader: evaluation DataLoader num_items: total number of items device: torch device ks: list of K values for metrics full_ranking: if True, rank against all items Returns: dict of metrics: HR@K, NDCG@K, MRR@K """ model.eval() all_hrs = {k: [] for k in ks} all_ndcgs = {k: [] for k in ks} all_mrrs = {k: [] for k in ks} start_time = time.time() for batch in eval_loader: batch_device = {k: v.to(device) for k, v in batch.items()} # Get user embeddings user_emb = model(batch_device) # (B, D) # Get item embeddings pos_ids = batch_device['positive_ids'] # (B,) neg_ids = batch_device['negative_ids'] # (B, num_neg) if full_ranking: # Rank against ALL items all_item_embs = model.item_embeddings.weight[1:] # Skip padding (0) scores = torch.matmul(user_emb, all_item_embs.t()) # (B, num_items) # Ground truth indices (0-indexed) gt_indices = pos_ids - 1 # Convert from 1-indexed to 0-indexed for i in range(scores.shape[0]): gt_score = scores[i, gt_indices[i]] rank = (scores[i] > gt_score).sum().item() + 1 for k in ks: all_hrs[k].append(1.0 if rank <= k else 0.0) all_ndcgs[k].append(1.0 / np.log2(rank + 1) if rank <= k else 0.0) all_mrrs[k].append(1.0 / rank if rank <= k else 0.0) else: # Sampled ranking: positive + negatives pos_emb = model.item_embeddings(pos_ids) # (B, D) neg_emb = model.item_embeddings(neg_ids) # (B, num_neg, D) pos_scores = (user_emb * pos_emb).sum(dim=-1, keepdim=True) # (B, 1) neg_scores = torch.einsum('bd,bnd->bn', user_emb, neg_emb) # (B, num_neg) # All scores: positive first, then negatives all_scores = torch.cat([pos_scores, neg_scores], dim=1) # (B, 1+num_neg) # Rank of positive item (0-indexed) ranks = (all_scores > pos_scores).sum(dim=1) + 1 # (B,) for k in ks: hits = (ranks <= k).float() ndcgs = torch.where( ranks <= k, 1.0 / torch.log2(ranks.float() + 1), torch.zeros_like(ranks.float()) ) mrrs = torch.where( ranks <= k, 1.0 / ranks.float(), torch.zeros_like(ranks.float()) ) all_hrs[k].extend(hits.cpu().tolist()) all_ndcgs[k].extend(ndcgs.cpu().tolist()) all_mrrs[k].extend(mrrs.cpu().tolist()) eval_time = time.time() - start_time metrics = {} for k in ks: metrics[f'HR@{k}'] = np.mean(all_hrs[k]) metrics[f'NDCG@{k}'] = np.mean(all_ndcgs[k]) metrics[f'MRR@{k}'] = np.mean(all_mrrs[k]) metrics['eval_time'] = eval_time return metrics @torch.no_grad() def compute_metrics_full( model, eval_data, num_items: int, device: torch.device, max_seq_len: int = 512, ks: List[int] = [5, 10, 20, 50], batch_size: int = 256, ) -> Dict[str, float]: """ Full-ranking evaluation (all items). More accurate but slower. """ model.eval() # Get all item embeddings all_item_embs = model.item_embeddings.weight[1:].to(device) # (num_items, D) all_hrs = {k: [] for k in ks} all_ndcgs = {k: [] for k in ks} for i in range(0, len(eval_data), batch_size): batch_data = eval_data[i:i+batch_size] # Prepare batch max_len = min(max(len(d['item_ids']) for d in batch_data), max_seq_len) item_ids_batch = [] mask_batch = [] gt_items = [] for d in batch_data: ids = d['item_ids'][-max_len:] pad_len = max_len - len(ids) item_ids_batch.append([0] * pad_len + ids) mask_batch.append([False] * pad_len + [True] * len(ids)) gt_items.append(d['next_item']) batch = { 'item_ids': torch.tensor(item_ids_batch, dtype=torch.long, device=device), 'mask': torch.tensor(mask_batch, dtype=torch.bool, device=device), } # Encode user_emb = model(batch) # (B, D) # Score all items scores = torch.matmul(user_emb, all_item_embs.t()) # (B, num_items) # Compute metrics for j, gt in enumerate(gt_items): gt_idx = gt - 1 # 0-indexed gt_score = scores[j, gt_idx] rank = (scores[j] > gt_score).sum().item() + 1 for k in ks: all_hrs[k].append(1.0 if rank <= k else 0.0) all_ndcgs[k].append(1.0 / np.log2(rank + 1) if rank <= k else 0.0) metrics = {} for k in ks: metrics[f'HR@{k}'] = np.mean(all_hrs[k]) metrics[f'NDCG@{k}'] = np.mean(all_ndcgs[k]) return metrics def print_comparison(mars_results: Dict, sasrec_results: Dict, ks: List[int] = [5, 10, 20]): """Pretty-print comparison between MARS and SASRec.""" print(f"\n{'='*70}") print(f"{'Metric':<15} | {'MARS':>10} | {'SASRec':>10} | {'Δ':>10} | {'Δ%':>10}") print(f"{'-'*70}") for k in ks: for metric_name in [f'HR@{k}', f'NDCG@{k}', f'MRR@{k}']: mars_val = mars_results.get(metric_name, 0) sasrec_val = sasrec_results.get(metric_name, 0) delta = mars_val - sasrec_val delta_pct = (delta / sasrec_val * 100) if sasrec_val > 0 else 0 marker = '↑' if delta > 0 else '↓' if delta < 0 else '=' print(f"{metric_name:<15} | {mars_val:>10.4f} | {sasrec_val:>10.4f} | " f"{delta:>+10.4f} | {marker} {abs(delta_pct):>7.2f}%") print(f"{'='*70}")