website / src /backend /api /services /kg_likelihood.py
Andrej Janchevski
feat(kganomaly): add streaming denoising backend with KG-likelihood metric
bc4fc5c
"""Per-step KG link-prediction log-likelihood for the denoising loop.
Wraps the math of `KGLikelihoodMetric.update` (see research/COINs-KGGeneration
.../metrics/abstract_metrics.py) in a one-shot, stateless helper. We query
the frozen KG embedder + link ranker on the edges currently present in the
argmax reconstruction and return their mean log-sigmoid score — a non-positive
(<= 0), higher-is-better value that rises toward 0 as the graph becomes cleaner.
"""
import logging
import torch
from torch.nn.functional import logsigmoid, one_hot
# Module-scoped logger, named after this module via __name__.
logger = logging.getLogger(name=__name__)
def _community_pair_order(c_s, c_t):
    """Return a permutation that makes every (c_s, c_t) community pair contiguous.

    Two stable passes: sort by source community, then stably by target
    community. The resulting order uses c_t as the primary key and c_s as the
    secondary key, so identical pairs end up in one contiguous run. Both passes
    are stable, so the permutation is deterministic for a given input.
    """
    by_source = torch.sort(c_s, stable=True).indices
    by_target = torch.sort(c_t[by_source], stable=True).indices
    return by_source[by_target]


def kg_edge_log_likelihood(E_int, X, X_index, X_c, kg_experiment):
    """Mean log-sigmoid link-ranker score over edges currently present.

    E_int: (n, n) integer tensor. 0 = no edge; otherwise class = relation_id + 1.
        Relation ids are cast to long before one-hot encoding, so any integer
        dtype works.
    X: (n, num_node_types) one-hot node types (unbatched, float).
    X_index: (n,) long dataset-global entity ids (unbatched).
    X_c: (n,) long community ids (unbatched).
    kg_experiment: COINs experiment exposing .embedder, .link_ranker,
        .loader.num_relations, .mini_batch_size, .device.

    Returns the mean of logsigmoid(score) over all present off-diagonal edges
    as a Python float. Since logsigmoid <= 0, the value is non-positive and
    closer to 0 is better. Returns None if no off-diagonal edges are present,
    or if any step fails (including importing the graph_completion package);
    failures are logged as a warning instead of being raised.
    """
    try:
        nz = E_int.nonzero(as_tuple=False)
        # Self-loops (diagonal entries) are not scored.
        nz = nz[nz[:, 0] != nz[:, 1]]
        if nz.numel() == 0:
            return None

        # Lazy project imports: done after the empty check so the no-edge path
        # never touches the research package, and inside the try so a missing
        # dependency degrades to "skipped" like any other scoring failure.
        from graph_completion.graphs.preprocess import QueryData
        from graph_completion.graphs.queries import Query

        embedder = kg_experiment.embedder
        link_ranker = kg_experiment.link_ranker.link_ranker
        num_relations = kg_experiment.loader.num_relations
        kg_device = kg_experiment.device
        mini_batch_size = kg_experiment.mini_batch_size

        s, t = nz[:, 0], nz[:, 1]
        # Stored class is relation_id + 1; shift back. one_hot requires int64.
        r = E_int[s, t].long() - 1

        # Group edges by community pair — the embedder batches by community pair.
        order = _community_pair_order(X_c[s].long(), X_c[t].long())
        s, t, r = s[order], t[order], r[order]

        e = [X_index[s].long(), X_index[t].long()]
        x = [X[s].float(), X[t].float()]
        c = [X_c[s].long(), X_c[t].long()]
        # Width num_relations + 1 is kept from the original metric; presumably
        # it matches the embedder's edge-attribute size — TODO confirm.
        edge_attr = [one_hot(r, num_relations + 1).float()]

        query = Query("1p")
        query.build_query_tree()
        query_data = QueryData(query, e=e, x=x, c=c, edge_attr=edge_attr).to(kg_device)

        scores = []
        with torch.no_grad():
            for batch in query_data.batch_split(mini_batch_size):
                q_emb, a_emb = embedder(batch)
                scores.append(link_ranker(q_emb, a_emb))
        # torch.cat on an empty list raises; nothing scored means no estimate.
        if not scores:
            return None
        scores = torch.cat(scores, dim=0).view(-1)
        if scores.numel() == 0:
            return None
        return float(logsigmoid(scores).mean().item())
    except Exception as exc:
        # Best-effort metric: never break the denoising loop over it.
        logger.warning("[kg-likelihood] skipped: %s", exc)
        return None