File size: 3,075 Bytes
bc4fc5c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
"""Per-step KG link-prediction log-likelihood for the denoising loop.

Wraps the math of `KGLikelihoodMetric.update` (see research/COINs-KGGeneration
.../metrics/abstract_metrics.py) in a one-shot, stateless helper. We query
the frozen KG embedder + link ranker on the edges currently present in the
argmax reconstruction and return their mean log-sigmoid score — a non-positive,
higher-is-better value (0 is the maximum) that rises as the graph becomes cleaner.
"""

import logging

import torch
from torch.nn.functional import logsigmoid, one_hot

logger = logging.getLogger(__name__)


def kg_edge_log_likelihood(E_int, X, X_index, X_c, kg_experiment):
    """Mean log-sigmoid link-ranker score over edges currently present.

    Every non-zero off-diagonal entry of ``E_int`` is treated as one edge;
    self-loops (diagonal entries) are ignored. Edges are reordered so that
    edges sharing a (c_s, c_t) community pair are contiguous before they are
    handed to the embedder. The returned mean itself does not depend on that
    order.

    E_int:   (n, n) long tensor. 0 = no edge; otherwise class = relation_id + 1.
    X:       (n, num_node_types) one-hot node types (unbatched, float).
    X_index: (n,) long dataset-global entity ids (unbatched).
    X_c:     (n,) long community ids (unbatched).
    kg_experiment: COINs experiment exposing .embedder, .link_ranker,
                   .loader.num_relations, .mini_batch_size, .device.

    Returns a Python float or None.
        The float is the mean log-sigmoid score per edge. It is always <= 0,
        and closer to 0 is better.
        None is returned in two cases:
        - No off-diagonal edges are present, or no batches are scored.
        - The scoring pass raises any exception. The exception is logged as
          a warning, not re-raised.

    NOTE(review): this function does not switch the embedder / link ranker
    to eval mode. If they contain dropout or similar layers, confirm the
    caller has already put them in eval mode, or the score will be
    stochastic.
    """
    from graph_completion.graphs.preprocess import QueryData
    from graph_completion.graphs.queries import Query

    try:
        embedder = kg_experiment.embedder
        link_ranker = kg_experiment.link_ranker.link_ranker
        num_relations = kg_experiment.loader.num_relations
        kg_device = kg_experiment.device
        mini_batch_size = kg_experiment.mini_batch_size

        # (source, target) index pairs of present edges, self-loops dropped.
        nz = E_int.nonzero(as_tuple=False)
        nz = nz[nz[:, 0] != nz[:, 1]]
        if nz.numel() == 0:
            return None
        s, t = nz[:, 0], nz[:, 1]
        r = E_int[s, t] - 1  # edge class -> relation id (class 0 = no edge)

        c_s, c_t = X_c[s].long(), X_c[t].long()

        # The embedder batches by community pair, so equal (c_s, c_t) pairs
        # must be contiguous. The two passes work as follows:
        # - Pass 1 sorts by c_s.
        # - Pass 2 stably sorts by c_t, which keys the final order on c_t
        #   with ties broken by c_s.
        # Both passes are stable, so the order is deterministic across runs
        # and devices.
        by_src = torch.sort(c_s, stable=True).indices
        by_tgt = torch.sort(c_t[by_src], stable=True).indices
        order = by_src[by_tgt]  # same as indexing with by_src, then by_tgt

        s, t = s[order], t[order]
        e = [X_index[s].long(), X_index[t].long()]
        x = [X[s].float(), X[t].float()]
        c = [c_s[order], c_t[order]]
        edge_attr = [one_hot(r[order], num_relations + 1).float()]

        # Single-hop ("1p") query: one (head, relation, tail) link per edge.
        q = Query("1p")
        q.build_query_tree()
        query_data = QueryData(q, e=e, x=x, c=c, edge_attr=edge_attr).to(kg_device)

        batch_scores = []
        with torch.no_grad():
            for qd_batch in query_data.batch_split(mini_batch_size):
                q_emb, a_emb = embedder(qd_batch)
                batch_scores.append(link_ranker(q_emb, a_emb))
        if not batch_scores:
            # torch.cat([]) would raise; nothing was scored.
            return None
        scores = torch.cat(batch_scores, dim=0).view(-1)
        if scores.numel() == 0:
            return None
        # logsigmoid output is floating point, so .item() yields a Python float.
        return logsigmoid(scores).mean().item()
    except Exception as exc:
        # Deliberate best-effort: this metric must never break the
        # denoising loop, so any failure becomes a logged skip.
        logger.warning("[kg-likelihood] skipped: %s", exc)
        return None