File size: 7,006 Bytes
2319f81
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
"""
Evaluation module for sequential recommendation models.

Metrics:
- HR@K (Hit Rate): Whether the ground truth item appears in top-K
- NDCG@K (Normalized Discounted Cumulative Gain): Position-aware ranking quality
- MRR@K (Mean Reciprocal Rank): Reciprocal of the rank of the first correct item
"""

import torch
import numpy as np
from typing import Dict, List
import time


@torch.no_grad()
def evaluate_model(
    model,
    eval_loader,
    num_items: int,
    device: torch.device,
    ks: List[int] = [5, 10, 20, 50],
    full_ranking: bool = False,
) -> Dict[str, float]:
    """
    Evaluate a sequential recommendation model.

    Uses sampled metrics by default (the positive item is ranked against the
    negatives shipped with each batch). Set full_ranking=True to rank the
    positive against every row of the item embedding table except row 0
    (slow but accurate).

    Ties are resolved in favour of the positive item: its rank is one plus
    the number of candidates with a *strictly* higher score, so ranks are
    1-indexed.

    Args:
        model: trained model. It is called with the batch dict and must
            expose an `item_embeddings` embedding layer; its output is
            expected to be one embedding per sample, shape (B, D).
        eval_loader: iterable of batches, each a dict of tensors containing
            'positive_ids' and, for sampled ranking, 'negative_ids'.
            'negative_ids' is only read when full_ranking is False.
        num_items: total number of items (kept for interface compatibility;
            not read by this function).
        device: torch device the batch tensors are moved to
        ks: list of K values for metrics (never mutated here)
        full_ranking: if True, rank against all items

    Returns:
        dict of metrics: HR@K, NDCG@K, MRR@K for every K in `ks`, plus
        'eval_time' in seconds. Metric values are NaN when `eval_loader`
        yields no samples.
    """
    model.eval()

    # 1-indexed rank of the ground-truth item for every evaluated sample,
    # collected per batch and reduced once at the end.
    rank_chunks = []

    start_time = time.time()

    for batch in eval_loader:
        batch_device = {key: value.to(device) for key, value in batch.items()}

        # User embeddings
        user_emb = model(batch_device)  # expected (B, D)
        pos_ids = batch_device['positive_ids']  # expected (B,)

        if full_ranking:
            # Rank against ALL items; row 0 of the table is the padding slot.
            all_item_embs = model.item_embeddings.weight[1:]
            scores = torch.matmul(user_emb, all_item_embs.t())  # (B, num_rows - 1)

            # Item ids are 1-indexed, score columns are 0-indexed.
            gt_indices = pos_ids - 1
            row_idx = torch.arange(scores.shape[0], device=scores.device)
            gt_scores = scores[row_idx, gt_indices].unsqueeze(1)  # (B, 1)

            ranks = (scores > gt_scores).sum(dim=1) + 1  # (B,)
        else:
            # Sampled ranking: positive vs. the negatives in the batch.
            neg_ids = batch_device['negative_ids']  # expected (B, num_neg)

            pos_emb = model.item_embeddings(pos_ids)  # (B, D)
            neg_emb = model.item_embeddings(neg_ids)  # (B, num_neg, D)

            pos_scores = (user_emb * pos_emb).sum(dim=-1, keepdim=True)  # (B, 1)
            neg_scores = torch.einsum('bd,bnd->bn', user_emb, neg_emb)   # (B, num_neg)

            # The positive never outranks itself, so counting the negatives
            # that beat it is enough.
            ranks = (neg_scores > pos_scores).sum(dim=1) + 1  # (B,)

        rank_chunks.append(ranks.cpu())

    if rank_chunks:
        all_ranks = torch.cat(rank_chunks).double()
    else:
        all_ranks = torch.empty(0, dtype=torch.float64)
    zeros = torch.zeros_like(all_ranks)

    metrics = {}
    for k in ks:
        in_top_k = all_ranks <= k
        ndcg = torch.where(in_top_k, 1.0 / torch.log2(all_ranks + 1), zeros)
        mrr = torch.where(in_top_k, 1.0 / all_ranks, zeros)

        metrics[f'HR@{k}'] = in_top_k.double().mean().item()
        metrics[f'NDCG@{k}'] = ndcg.mean().item()
        metrics[f'MRR@{k}'] = mrr.mean().item()

    metrics['eval_time'] = time.time() - start_time

    return metrics


@torch.no_grad()
def compute_metrics_full(
    model,
    eval_data,
    num_items: int,
    device: torch.device,
    max_seq_len: int = 512,
    ks: List[int] = [5, 10, 20, 50],
    batch_size: int = 256,
) -> Dict[str, float]:
    """
    Full-ranking evaluation (all items).
    More accurate but slower.

    Each sample's history is truncated to its most recent `max_seq_len`
    items, left-padded with 0 to the longest sequence in its batch, encoded
    by the model, and scored against every row of the item embedding table
    except row 0. The rank of the ground-truth item is one plus the number
    of items with a strictly higher score (1-indexed, ties favour the
    ground truth).

    Args:
        model: trained model. It is called with a dict holding 'item_ids'
            (long tensor) and 'mask' (bool tensor, True on real positions)
            and must expose an `item_embeddings` embedding layer; its output
            is expected to be one embedding per sample, shape (B, D).
        eval_data: sliceable sequence of dicts with 'item_ids' (the history,
            1-indexed item ids) and 'next_item' (the ground-truth item id).
        num_items: total number of items (kept for interface compatibility;
            not read by this function).
        device: torch device used for the batch tensors and item table
        max_seq_len: maximum history length fed to the model
        ks: list of K values for metrics (never mutated here)
        batch_size: number of samples encoded per forward pass

    Returns:
        dict of metrics: HR@K, NDCG@K and MRR@K for every K in `ks`.
        Values are NaN when `eval_data` is empty.
    """
    model.eval()

    # All item embeddings; row 0 of the table is the padding slot.
    all_item_embs = model.item_embeddings.weight[1:].to(device)

    # 1-indexed ground-truth ranks, collected per batch and reduced at the end.
    rank_chunks = []

    for start in range(0, len(eval_data), batch_size):
        batch_data = eval_data[start:start + batch_size]

        # Pad to the longest (truncated) sequence in this batch.
        max_len = min(max(len(d['item_ids']) for d in batch_data), max_seq_len)

        item_ids_batch = []
        mask_batch = []
        gt_items = []

        for d in batch_data:
            ids = list(d['item_ids'][-max_len:])  # keep the most recent items
            pad_len = max_len - len(ids)
            item_ids_batch.append([0] * pad_len + ids)
            mask_batch.append([False] * pad_len + [True] * len(ids))
            gt_items.append(d['next_item'])

        batch = {
            'item_ids': torch.tensor(item_ids_batch, dtype=torch.long, device=device),
            'mask': torch.tensor(mask_batch, dtype=torch.bool, device=device),
        }

        # Encode
        user_emb = model(batch)  # expected (B, D)

        # Score all items
        scores = torch.matmul(user_emb, all_item_embs.t())  # (B, num_rows - 1)

        # Item ids are 1-indexed, score columns are 0-indexed.
        gt_indices = torch.tensor(gt_items, dtype=torch.long, device=scores.device) - 1
        row_idx = torch.arange(scores.shape[0], device=scores.device)
        gt_scores = scores[row_idx, gt_indices].unsqueeze(1)  # (B, 1)

        ranks = (scores > gt_scores).sum(dim=1) + 1  # (B,)
        rank_chunks.append(ranks.cpu())

    if rank_chunks:
        all_ranks = torch.cat(rank_chunks).double()
    else:
        all_ranks = torch.empty(0, dtype=torch.float64)
    zeros = torch.zeros_like(all_ranks)

    metrics = {}
    for k in ks:
        in_top_k = all_ranks <= k
        ndcg = torch.where(in_top_k, 1.0 / torch.log2(all_ranks + 1), zeros)
        mrr = torch.where(in_top_k, 1.0 / all_ranks, zeros)

        metrics[f'HR@{k}'] = in_top_k.double().mean().item()
        metrics[f'NDCG@{k}'] = ndcg.mean().item()
        metrics[f'MRR@{k}'] = mrr.mean().item()

    return metrics


def print_comparison(mars_results: Dict, sasrec_results: Dict, ks: List[int] = [5, 10, 20]):
    """
    Print a side-by-side table of MARS vs. SASRec metrics.

    For every K in `ks` the rows HR@K, NDCG@K and MRR@K are printed with the
    MARS value, the SASRec value, the absolute difference (MARS - SASRec) and
    the magnitude of the relative difference in percent of the SASRec value.
    Metrics missing from either dict are treated as 0, and the relative
    difference is reported as 0 whenever the SASRec value is not positive.

    Args:
        mars_results: metric name -> value for MARS
        sasrec_results: metric name -> value for SASRec
        ks: K values to include in the table
    """
    heavy_rule = '=' * 70
    light_rule = '-' * 70

    print(f"\n{heavy_rule}")
    print(f"{'Metric':<15} | {'MARS':>10} | {'SASRec':>10} | {'Δ':>10} | {'Δ%':>10}")
    print(f"{light_rule}")

    for cutoff in ks:
        for label in (f'HR@{cutoff}', f'NDCG@{cutoff}', f'MRR@{cutoff}'):
            ours = mars_results.get(label, 0)
            baseline = sasrec_results.get(label, 0)
            diff = ours - baseline

            # Relative change is only meaningful against a positive baseline.
            if baseline > 0:
                rel_change = diff / baseline * 100
            else:
                rel_change = 0

            # Direction marker for the difference.
            if diff > 0:
                arrow = '↑'
            elif diff < 0:
                arrow = '↓'
            else:
                arrow = '='

            values_part = f"{label:<15} | {ours:>10.4f} | {baseline:>10.4f} | "
            change_part = f"{diff:>+10.4f} | {arrow} {abs(rel_change):>7.2f}%"
            print(values_part + change_part)

    print(f"{heavy_rule}")