File size: 3,832 Bytes
e49b23b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
from pathlib import Path
import numpy as np
import torch
import torch.nn.functional as F
from sklearn.metrics import roc_auc_score, average_precision_score
from sentence_transformers import SentenceTransformer
import argparse
from dataset.ogbn_link_pred_dataset import OGBNLinkPredDataset, OGBNLinkPredNegDataset

# Default `batch_size` of eval_edges_cos: edges are scored in consecutive
# slices of at most this many edges.
BATCH_SIZE_EDGES = 100_000  # edge batching for scoring


def parse_args(argv=None) -> argparse.Namespace:
    """Parse the command-line switches that configure the evaluation run.

    Both options use ``argparse.BooleanOptionalAction``, so each flag also
    has a negated ``--no-...`` form and the last occurrence wins.

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the
            default) makes argparse read ``sys.argv[1:]``, which is what the
            script entry point relies on; passing a list allows the parser
            to be driven programmatically.

    Returns:
        Namespace with two booleans, ``custom_neg`` and ``bert_embed``, each
        ``False`` unless the corresponding flag is supplied.
    """
    parser = argparse.ArgumentParser(
        description="Evaluate frozen node embeddings on link prediction "
        "(ROC-AUC and average precision).",
    )
    parser.add_argument(
        "--custom-neg",
        action=argparse.BooleanOptionalAction,
        default=False,
        help="use the hard-negative dataset instead of random negatives",
    )
    parser.add_argument(
        "--bert-embed",
        action=argparse.BooleanOptionalAction,
        default=False,
        help="score with sentence-transformer embeddings instead of the "
        "dataset's skipgram node features",
    )
    return parser.parse_args(argv)


@torch.no_grad()
def eval_edges_cos(global_emb, edge_index, edge_label, batch_size=BATCH_SIZE_EDGES):
    """Score candidate edges by embedding dot product and report ROC-AUC / AP.

    Each edge ``(src, dst)`` is scored as the row-wise dot product of the two
    node embeddings. That value is a cosine similarity only when the rows of
    ``global_emb`` are already L2-normalized; this function does not
    normalize them itself.

    Args:
        global_emb: Embedding table indexed by global node id along dim 0,
            with the feature axis at dim 1.
        edge_index: Tensor of shape ``[2, M]`` holding global node ids.
        edge_label: Tensor of length ``M`` with labels in ``{0, 1}``.
        batch_size: Maximum number of edges scored per slice.

    Returns:
        Tuple ``(roc, ap)`` as computed by ``roc_auc_score`` and
        ``average_precision_score`` over all edges.

    Raises:
        AssertionError: If ``edge_index`` is not ``[2, M]``, the label count
            differs from the edge count, either class is missing (which also
            covers empty input), or a node id is negative or out of range.
    """
    assert edge_index.dim() == 2 and edge_index.size(0) == 2
    num_edges = edge_index.size(1)
    # A longer label tensor would otherwise be silently truncated by slicing.
    assert edge_label.size(0) == num_edges, "edge_label / edge_index size mismatch."
    # Checked before min()/max() so empty input fails here with a clear message.
    assert (edge_label == 0).any() and (edge_label == 1).any(), "Need both classes."
    # Negative ids would index from the end of the table instead of failing.
    assert edge_index.min() >= 0, "Negative edge node id."
    assert edge_index.max() < global_emb.size(0), "Edge node id out of range."

    device = global_emb.device
    scores_list, labels_list = [], []
    for start in range(0, num_edges, batch_size):
        stop = min(start + batch_size, num_edges)
        src = edge_index[0, start:stop].to(device)
        dst = edge_index[1, start:stop].to(device)
        # Dot product per edge; equals cosine for unit-norm rows.
        scores = (global_emb[src] * global_emb[dst]).sum(dim=1)
        # Move each slice to host memory right away; metrics run on NumPy.
        scores_list.append(scores.float().cpu().numpy())
        labels_list.append(edge_label[start:stop].cpu().numpy())

    y_scores = np.concatenate(scores_list)
    y_true = np.concatenate(labels_list)
    roc = roc_auc_score(y_true, y_scores)
    ap = average_precision_score(y_true, y_scores)
    return roc, ap


if __name__ == "__main__":
    # Script entry point: pick the dataset and embedding source from the CLI
    # flags, sanity-check the splits, then print ROC-AUC / AP for the
    # validation and test edges.
    args = parse_args()
    USE_CUSTOM_NEG = args.custom_neg
    USE_BERT_EMBED = args.bert_embed
    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # --- Load dataset + frozen embeddings ---
    if USE_CUSTOM_NEG:
        print("using hard negatives")
        dataset = OGBNLinkPredNegDataset(val_size=0.1, test_size=0.2)
    else:
        print("using random negatives")
        dataset = OGBNLinkPredDataset(val_size=0.1, test_size=0.2)

    if USE_BERT_EMBED:
        print("using BERT embeds")
        if Path("model/embeddings.pth").exists():
            # Reuse the cached encoding of the corpus instead of re-encoding.
            # NOTE(review): torch.load deserializes with pickle unless it is
            # restricted to weights-only loading; only load a cache file that
            # this script produced.
            emb = torch.load("model/embeddings.pth", map_location=DEVICE)
        else:
            st = SentenceTransformer("bongsoo/kpf-sbert-128d-v1", device=DEVICE)
            emb = st.encode(
                dataset.corpus, convert_to_tensor=True, show_progress_bar=True
            )
            Path("model").mkdir(parents=True, exist_ok=True)
            # Cache the raw (un-normalized) encoder output for later runs.
            torch.save(emb, "model/embeddings.pth")
    else:
        print("using skipgram embeds")
        emb = dataset.data.x

    # eval_edges_cos scores edges with a plain dot product, which is a cosine
    # similarity only for unit-norm rows. Normalize for BOTH embedding sources
    # so the two modes report the same kind of score (a no-op if the features
    # were already unit-norm).
    # Assumes the node features are a floating-point tensor — TODO confirm.
    emb = F.normalize(emb.to(DEVICE), p=2, dim=1)

    train_data, val_data, test_data = dataset.get_splits()

    # Sanity checks
    for split_name, data in [
        ("train", train_data),
        ("val", val_data),
        ("test", test_data),
    ]:
        assert data.edge_label_index.size(1) == data.edge_label.size(0), (
            f"{split_name} size mismatch"
        )
        # Both classes are required; the message names both because either
        # one being absent trips this check.
        assert (data.edge_label == 0).any() and (data.edge_label == 1).any(), (
            f"{split_name} lacks positives or negatives"
        )
        assert data.edge_label_index.max() < emb.size(0), (
            f"{split_name} has node ids >= num_nodes"
        )

    val_roc, val_ap = eval_edges_cos(
        emb, val_data.edge_label_index, val_data.edge_label
    )
    test_roc, test_ap = eval_edges_cos(
        emb, test_data.edge_label_index, test_data.edge_label
    )

    print(f"Val ROC-AUC:  {val_roc:.4f}, Val AP:  {val_ap:.4f}")
    print(f"Test ROC-AUC: {test_roc:.4f}, Test AP: {test_ap:.4f}")