File size: 6,114 Bytes
549c270
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
# src/utils/metrics.py

from __future__ import annotations
import csv, time, subprocess
from pathlib import Path
from typing import Dict, Any, List
import numpy as np
import json
import faiss


def _git_sha() -> str:
    """
    Return the abbreviated hash of the current git HEAD, or "nogit".

    Runs ``git rev-parse --short HEAD`` in the current working directory.
    Any failure (git missing, not inside a repository, non-zero exit, ...)
    is swallowed on purpose: this is best-effort metadata and must never
    make metric logging fail.
    """
    command = ["git", "rev-parse", "--short", "HEAD"]
    try:
        # check=True turns a non-zero exit status into CalledProcessError.
        completed = subprocess.run(command, stdout=subprocess.PIPE, text=True, check=True)
    except Exception:
        return "nogit"
    return completed.stdout.strip()


def append_metrics_dict(row: Dict[str, Any], csv_path: str | Path = "logs/metrics.csv", no_log: bool = False):
    """
    Append a single row of metrics/metadata to a CSV file (default: logs/metrics.csv).

    - Adds "timestamp" (local time) and "git_sha" unless the caller supplied them.
    - Writes a header when the file does not exist yet or is empty.
    - When the file already has a header, the row is written in the header's
      column order; columns missing from the row are left empty.
    - When the row contains keys the header lacks, the file is rewritten with
      the new columns appended to the header (earlier rows get empty cells).

    Args:
        row: mapping of column name to value; it is copied, never mutated.
        csv_path: destination file; parent directories are created as needed.
        no_log: when True the call is a no-op.
    """
    if no_log:
        return
    path = Path(csv_path)
    path.parent.mkdir(parents=True, exist_ok=True)

    row = dict(row)  # shallow copy so the caller's dict is never mutated
    row.setdefault("timestamp", time.strftime("%Y-%m-%d %H:%M:%S"))
    if "git_sha" not in row:
        # Only spawn the git subprocess when the caller did not provide a SHA.
        row["git_sha"] = _git_sha()

    # Read the header already on disk (if any) so new rows stay aligned with it.
    header: List[str] = []
    if path.exists() and path.stat().st_size > 0:
        with path.open("r", newline="") as f:
            header = next(csv.reader(f), [])

    if not header:
        # New or empty file: this row defines the columns.
        with path.open("a", newline="") as f:
            writer = csv.DictWriter(f, fieldnames=list(row.keys()))
            writer.writeheader()
            writer.writerow(row)
        return

    new_columns = [key for key in row if key not in header]
    if not new_columns:
        # Same schema (possibly a subset): append in the existing column order.
        with path.open("a", newline="") as f:
            csv.DictWriter(f, fieldnames=header, restval="").writerow(row)
        return

    # Schema grew: rewrite the file with a widened header. The new content goes
    # to a sibling temp file first so a failure mid-write cannot truncate the log.
    widened = header + new_columns
    with path.open("r", newline="") as f:
        previous_rows = list(csv.reader(f))[1:]
    tmp_path = path.with_name(path.name + ".tmp")
    with tmp_path.open("w", newline="") as f:
        writer = csv.writer(f)
        writer.writerow(widened)
        for old in previous_rows:
            # Pad short rows; rows that are already longer are kept untouched.
            writer.writerow(old + [""] * (len(widened) - len(old)))
        writer.writerow([row.get(column, "") for column in widened])
    tmp_path.replace(path)


def append_metrics(dataset: str, method: str, hitk: float, ndcg: float):
    """
    Convenience wrapper for CoVE experiments: log one evaluation result to CSV.

    Packs the four positional values into the row dictionary expected by
    ``append_metrics_dict`` (columns "dataset", "method", "hit@k", "ndcg@k")
    and appends it using that function's default destination.
    """
    columns = ("dataset", "method", "hit@k", "ndcg@k")
    values = (dataset, method, hitk, ndcg)
    append_metrics_dict(dict(zip(columns, values)))


def compute_metrics(preds: List[List[str]], labels: List[List[str]], k: int = 10):
    """
    Compute Hit@k and NDCG@k for FAISS-based predictions.

    Args:
        preds: one ranked list of predicted item ids per query.
        labels: one list of ground-truth ids per query; only the first
            element is used as the target.
        k: rank cutoff.

    Returns:
        dict with "hit@k" and "ndcg@k", both averaged over ``len(preds)``.
        Empty ``preds`` yields 0.0 for both instead of dividing by zero, and a
        query whose label list is empty is counted as a miss.
    """
    total = len(preds)
    if total == 0:
        return {"hit@k": 0.0, "ndcg@k": 0.0}

    hit, ndcg = 0.0, 0.0
    for ranked, truth in zip(preds, labels):
        if len(truth) == 0:
            # No ground truth available: nothing can be hit for this query.
            continue
        target = truth[0]
        top_k = list(ranked[:k])
        if target in top_k:
            hit += 1.0
            # Rank is 0-based, hence the +2 inside the log discount.
            ndcg += 1.0 / np.log2(top_k.index(target) + 2)

    return {
        "hit@k": hit / total,
        "ndcg@k": ndcg / total,
    }


def compute_metrics_cove(reranked: Dict[str, List[str]], true_next_items: List[str], k: int = 10):
    """
    Compute Hit@k and NDCG@k for logits-based reranking.

    Args:
        reranked: dict of {user_idx: [item_ids...]}. The position ``i`` of a
            user in ``true_next_items`` is looked up as the string ``str(i)``
            first and, if that key is absent, as the integer ``i``.
        true_next_items: ground-truth item ids, ordered by user index.
        k: rank cutoff.

    Returns:
        dict with "hit@k" and "ndcg@k", averaged over ``len(true_next_items)``.
        Users without an entry in ``reranked`` count as misses. Empty
        ``true_next_items`` yields 0.0 for both instead of dividing by zero.
    """
    total = len(true_next_items)
    if total == 0:
        return {"hit@k": 0.0, "ndcg@k": 0.0}

    hit, ndcg = 0.0, 0.0
    for i, true_item in enumerate(true_next_items):
        ranked = reranked.get(str(i))
        if ranked is None:
            # Tolerate dicts keyed by int (e.g. built in memory rather than
            # loaded from JSON) instead of silently scoring them as misses.
            ranked = reranked.get(i, [])
        top_k = list(ranked[:k])
        if true_item in top_k:
            hit += 1.0
            # Rank is 0-based, hence the +2 inside the log discount.
            ndcg += 1.0 / np.log2(top_k.index(true_item) + 2)

    return {
        "hit@k": hit / total,
        "ndcg@k": ndcg / total,
    }


def compute_hit_ndcg(sequences, scores, top_k=10):
    """
    Evaluate Hit@K and NDCG@K together, visiting each sequence once.

    The last element of a sequence is its target. Sequences with fewer than
    two elements are skipped and do not count towards the averages. The
    candidates for sequence ``i`` are the ``top_k`` indices of ``scores[i]``
    with the highest scores, in descending score order.

    Returns a dict with "hit@k" and "ndcg@k"; both are 0.0 when no sequence
    was evaluated.
    """

    def _discounted_sum(gains):
        # DCG with a log2 discount; positions are 0-based, hence the +2.
        acc = 0
        for position, gain in enumerate(gains):
            acc += gain / np.log2(position + 2)
        return acc

    n_hits = 0.0
    ndcg_sum = 0.0
    n_evaluated = 0

    for row, seq in enumerate(sequences):
        if len(seq) >= 2:
            target = seq[-1]
            ranked = np.argsort(scores[row])[::-1][:top_k]
            gains = [1 if item == target else 0 for item in ranked]
            if target in ranked:
                n_hits += 1.0
            # Normalise by the DCG of the best possible ordering of these gains.
            best = _discounted_sum(sorted(gains, reverse=True))
            if best > 0:
                ndcg_sum += _discounted_sum(gains) / best
            n_evaluated += 1

    if not n_evaluated:
        return {"hit@k": 0.0, "ndcg@k": 0.0}
    return {
        "hit@k": n_hits / n_evaluated,
        "ndcg@k": ndcg_sum / n_evaluated,
    }


def hitrate(preds: List[List[str]], labels: List[List[str]], k: int = 10) -> float:
    """
    Return Hit@K: the share of queries whose target is among the top-k predictions.

    The target of a query is the first element of its label list. The count
    is divided by ``len(preds)``; an empty ``preds`` gives 0.0.
    """
    hit_count = sum(
        1 for ranked, truth in zip(preds, labels) if truth[0] in ranked[:k]
    )
    n_queries = len(preds)
    if n_queries > 0:
        return hit_count / n_queries
    return 0.0


def evaluate_faiss_index(index, item_embeddings, labels, topk=(5, 10)):
    """
    Evaluate a FAISS index with item embeddings and ground-truth labels.

    Arguments:
        - index: FAISS index object (anything exposing ``search(queries, k)``
          that returns ``(distances, indices)``)
        - item_embeddings: numpy array of shape (N, D); row ``query_id`` is
          used as the query vector
        - labels: dict mapping query_id (int) to ground-truth item_id (int)
        - topk: iterable of cutoff values to evaluate (e.g., [5, 10])
    Returns:
        - dict with hit@k and ndcg@k for each k. With empty ``labels`` every
          metric is 0.0 (instead of NaN); with empty ``topk`` the dict is empty.
    """
    cutoffs = list(topk)
    if not cutoffs:
        return {}

    results = {}
    if not labels:
        # Nothing to evaluate: report zeros rather than the mean of an empty list.
        for k in cutoffs:
            results[f"hit@{k}"] = 0.0
            results[f"ndcg@{k}"] = 0.0
        return results

    max_k = max(cutoffs)  # search once per query at the largest cutoff
    hits = {k: [] for k in cutoffs}
    ndcgs = {k: [] for k in cutoffs}

    for user_id, true_item in labels.items():
        # FAISS expects C-contiguous float32 queries of shape (n_queries, D).
        query_vec = np.ascontiguousarray(item_embeddings[user_id], dtype=np.float32).reshape(1, -1)
        _, indices = index.search(query_vec, max_k)
        retrieved = np.asarray(indices[0])

        for k in cutoffs:
            matches = np.flatnonzero(retrieved[:k] == true_item)
            if matches.size:
                hits[k].append(1)
                # First matching position is the 0-based rank, hence the +2.
                ndcgs[k].append(1 / np.log2(matches[0] + 2))
            else:
                hits[k].append(0)
                ndcgs[k].append(0)

    for k in cutoffs:
        results[f"hit@{k}"] = np.mean(hits[k])
        results[f"ndcg@{k}"] = np.mean(ndcgs[k])

    return results


def load_labels(path):
    """
    Load labels from a JSON file and return the parsed content.

    The file is read as UTF-8 regardless of the platform's default locale
    encoding. Note that keys of JSON objects are always strings after
    parsing, even when they look numeric.
    """
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)


# ↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓
# Additional Metric Functions for Logits-based Evaluation

def hitrate_at_k(predictions, ground_truth, k=10):
    """
    Return Hit@k: the mean over queries of 1 if the ground-truth item is
    among the first ``k`` predictions, else 0.

    predictions: list of ranked prediction lists, one per query
    ground_truth: list of ground-truth items, aligned with ``predictions``

    Returns 0.0 when there is nothing to evaluate (instead of NaN plus a
    numpy RuntimeWarning from averaging an empty list).
    """
    hits = [1 if gt in pred[:k] else 0 for pred, gt in zip(predictions, ground_truth)]
    if not hits:
        return 0.0
    return np.mean(hits)

def ndcg_at_k(predictions, ground_truth, k=10):
    """
    Return NDCG@k averaged over queries, with a single relevant item per query.

    predictions: list of list of predicted ASINs (any sliceable sequence works)
    ground_truth: list of ground-truth ASINs

    A query scores 1 / log2(rank + 2) when the ground-truth item appears at
    0-based position ``rank`` within the first ``k`` predictions, else 0.0.
    Returns 0.0 when there is nothing to evaluate (instead of NaN plus a
    numpy RuntimeWarning from averaging an empty list).
    """
    ndcgs = []
    for pred, gt in zip(predictions, ground_truth):
        # Materialise the cutoff as a list so .index works for tuples/arrays too.
        top_k = list(pred[:k])
        if gt in top_k:
            rank = top_k.index(gt)
            ndcgs.append(1.0 / np.log2(rank + 2))  # rank + 1 (0-based) + 1
        else:
            ndcgs.append(0.0)
    if not ndcgs:
        return 0.0
    return np.mean(ndcgs)