"""Per-step KG link-prediction log-likelihood for the denoising loop. Wraps the math of `KGLikelihoodMetric.update` (see research/COINs-KGGeneration .../metrics/abstract_metrics.py) in a one-shot, stateless helper. We query the frozen KG embedder + link ranker on the edges currently present in the argmax reconstruction and return their mean log-sigmoid score — a positive higher-is-better value that rises as the graph becomes cleaner. """ import logging import torch from torch.nn.functional import logsigmoid, one_hot logger = logging.getLogger(__name__) def kg_edge_log_likelihood(E_int, X, X_index, X_c, kg_experiment): """Mean log-sigmoid link-ranker score over edges currently present. E_int: (n, n) long tensor. 0 = no edge; otherwise class = relation_id + 1. X: (n, num_node_types) one-hot node types (unbatched, float). X_index: (n,) long dataset-global entity ids (unbatched). X_c: (n,) long community ids (unbatched). kg_experiment: COINs experiment exposing .embedder, .link_ranker, .loader.num_relations, .mini_batch_size, .device. Returns a Python float (log-likelihood per edge) or None if no edges are present or the scoring pass fails for any reason. """ from graph_completion.graphs.preprocess import QueryData from graph_completion.graphs.queries import Query try: embedder = kg_experiment.embedder link_ranker = kg_experiment.link_ranker.link_ranker num_relations = kg_experiment.loader.num_relations kg_device = kg_experiment.device mini_batch_size = kg_experiment.mini_batch_size nz = E_int.nonzero(as_tuple=False) if nz.numel() == 0: return None nz = nz[nz[:, 0] != nz[:, 1]] if nz.numel() == 0: return None s, t = nz[:, 0], nz[:, 1] r = E_int[s, t] - 1 e_s, e_t = X_index[s].long(), X_index[t].long() x_s, x_t = X[s].float(), X[t].float() c_s, c_t = X_c[s].long(), X_c[t].long() # Stable sort by (c_s, c_t) — the embedder batches by community pair. s_sort = torch.argsort(c_s) t_sort = torch.sort(c_t[s_sort], stable=True).indices pick = lambda v: v[s_sort][t_sort] e = [pick(e_s), pick(e_t)] x = [pick(x_s), pick(x_t)] c = [pick(c_s), pick(c_t)] r = pick(r) edge_attr = [one_hot(r, num_relations + 1).float()] q = Query("1p") q.build_query_tree() query_data = QueryData(q, e=e, x=x, c=c, edge_attr=edge_attr).to(kg_device) scores = [] with torch.no_grad(): for qd_batch in query_data.batch_split(mini_batch_size): q_emb, a_emb = embedder(qd_batch) scores.append(link_ranker(q_emb, a_emb)) scores = torch.cat(scores, dim=0).view(-1) if scores.numel() == 0: return None return float(logsigmoid(scores).mean().item()) except Exception as exc: logger.warning("[kg-likelihood] skipped: %s", exc) return None