Andrej Janchevski
feat(kganomaly): add streaming denoising backend with KG-likelihood metric
bc4fc5c | """Per-step KG link-prediction log-likelihood for the denoising loop. | |
| Wraps the math of `KGLikelihoodMetric.update` (see research/COINs-KGGeneration | |
| .../metrics/abstract_metrics.py) in a one-shot, stateless helper. We query | |
| the frozen KG embedder + link ranker on the edges currently present in the | |
| argmax reconstruction and return their mean log-sigmoid score — a non-positive | |
| (<= 0), higher-is-better value that rises toward 0 as the graph becomes cleaner. | |
| """ | |
| import logging | |
| import torch | |
| from torch.nn.functional import logsigmoid, one_hot | |
# Logger for this module; kg_edge_log_likelihood reports skipped scoring passes through it.
logger = logging.getLogger(name=__name__)
def kg_edge_log_likelihood(E_int, X, X_index, X_c, kg_experiment):
    """Mean log-sigmoid link-ranker score over the edges currently present.

    Every off-diagonal non-zero cell ``E_int[s, t]`` is scored as a 1-hop
    ("1p") query from entity ``X_index[s]`` to entity ``X_index[t]`` under
    relation ``E_int[s, t] - 1``. Scoring uses the experiment's embedder and
    link ranker, and the per-edge log-sigmoid scores are averaged.

    Args:
        E_int: (n, n) long tensor. 0 = no edge; otherwise class =
            relation_id + 1. Diagonal entries (self-loops) are ignored.
        X: (n, num_node_types) one-hot node types (unbatched, float).
        X_index: (n,) long dataset-global entity ids (unbatched).
        X_c: (n,) long community ids (unbatched).
        kg_experiment: COINs experiment exposing ``.embedder``,
            ``.link_ranker.link_ranker``, ``.loader.num_relations``,
            ``.mini_batch_size`` and ``.device``.

    Returns:
        The mean log-likelihood per edge as a Python float. It is a mean of
        log-sigmoid values, so it is non-positive; higher (closer to 0) is
        better. Returns None if no non-self-loop edges are present, or if
        the scoring pass raises any ``Exception``. In that case the
        exception is logged as a warning instead of propagating.

    Raises:
        ImportError: if ``graph_completion`` cannot be imported. The lazy
            imports run before the best-effort ``try`` block.
    """
    from graph_completion.graphs.preprocess import QueryData
    from graph_completion.graphs.queries import Query

    try:
        embedder = kg_experiment.embedder
        link_ranker = kg_experiment.link_ranker.link_ranker
        num_relations = kg_experiment.loader.num_relations
        kg_device = kg_experiment.device
        mini_batch_size = kg_experiment.mini_batch_size

        # (k, 2) coordinates of present edges, with self-loops dropped.
        nz = E_int.nonzero(as_tuple=False)
        nz = nz[nz[:, 0] != nz[:, 1]]
        if nz.numel() == 0:
            return None

        s, t = nz[:, 0], nz[:, 1]
        r = E_int[s, t] - 1  # edge class -> relation id
        e_s, e_t = X_index[s].long(), X_index[t].long()
        x_s, x_t = X[s].float(), X[t].float()
        c_s, c_t = X_c[s].long(), X_c[t].long()

        # Group edges by community pair, which the embedder batches on.
        # The sort takes two stable passes: first by c_s, then by c_t. The
        # last pass is the primary key, so the final order is lexicographic
        # in (c_t, c_s). Equal pairs keep their nonzero() order, which makes
        # the order deterministic. The first pass was previously an unstable
        # argsort.
        s_sort = torch.sort(c_s, stable=True).indices
        t_sort = torch.sort(c_t[s_sort], stable=True).indices
        # Compose both permutations once: v[s_sort][t_sort] == v[order].
        order = s_sort[t_sort]

        e = [e_s[order], e_t[order]]
        x = [x_s[order], x_t[order]]
        c = [c_s[order], c_t[order]]
        # One-hot width num_relations + 1 is kept from the reference metric.
        # NOTE(review): confirm what the extra class slot represents.
        edge_attr = [one_hot(r[order], num_relations + 1).float()]

        q = Query("1p")
        q.build_query_tree()
        query_data = QueryData(q, e=e, x=x, c=c, edge_attr=edge_attr).to(kg_device)

        # NOTE(review): no .eval() call is made here. This assumes the caller
        # keeps the embedder and link ranker in eval mode; otherwise any
        # dropout would make the scores stochastic.
        batch_scores = []
        with torch.no_grad():
            for qd_batch in query_data.batch_split(mini_batch_size):
                q_emb, a_emb = embedder(qd_batch)
                batch_scores.append(link_ranker(q_emb, a_emb))
        # torch.cat raises on an empty list; treat "no batches" as "no edges".
        if not batch_scores:
            return None
        scores = torch.cat(batch_scores, dim=0).view(-1)
        if scores.numel() == 0:
            return None
        return float(logsigmoid(scores).mean().item())
    except Exception as exc:
        # Best-effort by design: any scoring failure yields None, not an error.
        logger.warning("[kg-likelihood] skipped: %s", exc)
        return None