# website/src/backend/api/services/coins_inference.py
# Author: Andrej Janchevski
# Commit message: fix(coins): score full hit community instead of only valid answers
# Commit: 2d03843
"""COINs link prediction / query answering — inference logic.
This module is intentionally free of Django settings and registry state.
It receives a fully-prepared Experiment object and runs the inference pipeline.
"""
import time
import numpy as np
import torch as pt
from torch.nn.functional import one_hot
from api.exceptions import InvalidRequestError
from api.services.constants import QUERY_STRUCTURE_INTERNAL, QUERY_TREE_MAPPINGS, QUERY_STRUCTURES
from api.utils import clean_entity_name, clean_relation_name
# Research-code imports — available after settings.py adds the repos to sys.path.
from graph_completion.graphs.queries import Query, get_all_answers, get_node_cut_cache, query_edge_r_to_int
from graph_completion.graphs.preprocess import QueryData
# Step-2 mini-batch cap: the link-ranker runs a full forward pass per candidate,
# and the hit community can hold thousands of nodes on Freebase. Splitting keeps
# CPU memory bounded without changing results.
# Used as the chunk size passed to QueryData.batch_split() when the candidates
# of the hit community are scored in coins_predict_inner.
SCORING_MINI_BATCH_SIZE = 512
def _edge_attr_columns(qi_skeleton, num_tree_edges, batch_size, num_relations, device):
    """Build one one-hot relation feature matrix per query-tree edge.

    Each matrix has ``batch_size`` identical rows of width ``num_relations + 1``.
    Projection edges are labelled ``"p{rel_id}"`` and encode ``rel_id``;
    every other edge (intersection, label ``"i"``) uses the extra
    "no relation" slot at index ``num_relations``.
    """
    columns = []
    for j in range(num_tree_edges):
        r_label = qi_skeleton.es[j]["r"]
        r_id = int(r_label[1:]) if "p" in r_label else num_relations
        columns.append(
            one_hot(
                pt.full([batch_size], r_id, dtype=pt.long, device=device),
                num_relations + 1,
            ).float()
        )
    return columns


def coins_predict_inner(experiment, dataset_id, algorithm, query_structure_id,
                        anchors, variables, relations_map, top_k):
    """Run a single COINs query and return the CoinsPredictResponse dict.

    Parameters
    ----------
    experiment
        Fully-prepared Experiment (embedder + link_ranker + loader loaded).
    dataset_id, algorithm
        Passed through to the response (dataset_id is also used to clean names).
    query_structure_id
        Key into QUERY_TREE_MAPPINGS / QUERY_STRUCTURE_INTERNAL.
    anchors
        {api_node_id: entity_id} for all anchor nodes.
    variables
        {api_node_id: entity_id} for variable nodes (may be empty).
    relations_map
        {api_edge_id: relation_id} for all edges.
    top_k
        Number of top predictions to return (max 10 per the API contract; this
        function only caps it at the number of candidates).

    Returns
    -------
    dict
        The response assembled by ``_build_response``.

    Raises
    ------
    InvalidRequestError
        If no entity in the knowledge graph satisfies the query, if the hit
        community yields no candidate, or if a variable cannot be resolved
        (raised from ``_resolve_variables``).
    """
    embedder = experiment.embedder
    link_ranker = experiment.link_ranker
    loader = experiment.loader
    device = next(embedder.parameters()).device
    qs_mapping = QUERY_TREE_MAPPINGS[query_structure_id]

    num_nodes = loader.num_nodes
    num_relations = loader.num_relations
    num_node_types = loader.num_node_types
    num_communities = loader.num_communities
    community_membership = embedder.community_membership  # PT tensor [num_nodes], long
    node_types_tensor = embedder.node_types  # PT tensor [num_nodes], long

    # Full-KG adjacency (train + val + test) — graph_indexes cover train only,
    # which would miss answers in val/test splits.
    adj_s_to_t = experiment.full_adj_s_to_t

    # ---- Build Query tree ----
    query = Query(QUERY_STRUCTURE_INTERNAL[query_structure_id])
    query.build_query_tree()
    query.get_node_cut()
    num_tree_nodes = query.query_tree.vcount()
    num_tree_edges = query.query_tree.ecount()

    # ---- Build query_instance_mapped skeleton ----
    qi_skeleton = query.query_tree.copy()
    entities_skeleton = [-1] * num_tree_nodes  # -1 = unresolved; avoids confusion with entity 0
    for api_id, entity_id in anchors.items():
        tree_idx = qs_mapping["nodes"][api_id]
        entities_skeleton[tree_idx] = int(entity_id)
    qi_skeleton.vs["e"] = entities_skeleton
    for api_id, rel_id in relations_map.items():
        edge_idx = qs_mapping["edges"][api_id]
        qi_skeleton.es[edge_idx]["r"] = f"p{rel_id}"
    # intersection edges keep "i" from build_query_tree

    # ---- Resolve variable entities ----
    _resolve_variables(qi_skeleton, query, entities_skeleton, qs_mapping,
                       variables, adj_s_to_t)

    # ---- Step 1: Community scoring ----
    t1_start = time.perf_counter()
    all_communities = pt.arange(num_communities, dtype=pt.long, device=device)
    e_c = []
    for i in range(num_tree_nodes):
        # The answer node sweeps over every community. Nodes that are still
        # unresolved (-1, e.g. phantom-i nodes attached to the answer) must do
        # the same, exactly as step 2 gives them each candidate: indexing
        # community_membership with -1 would silently read the LAST node's
        # community instead.
        if i == query.query_answer or entities_skeleton[i] == -1:
            e_c.append(all_communities)
        else:
            com = int(community_membership[entities_skeleton[i]].item())
            e_c.append(pt.full([num_communities], com, dtype=pt.long, device=device))
    edge_attr_c = _edge_attr_columns(qi_skeleton, num_tree_edges, num_communities,
                                     num_relations, device)

    with pt.no_grad():
        community_query = QueryData(query, e_c=e_c, edge_attr_c=edge_attr_c)
        c_q_emb, c_a_emb = embedder.embed_communities(community_query)
        community_scores = link_ranker(c_q_emb, c_a_emb, for_communities=True)

    # Sort communities by descending score for the search loop
    community_order = community_scores.argsort(descending=True)  # [K]
    step1_ms = (time.perf_counter() - t1_start) * 1000.0

    # ---- Step 2: Score every entity in the hit community ----
    # Mirrors rank_samples (experiments.py:782-786): the hit community is the
    # first one in step-1 order that contains *any* KG-valid answer; within it
    # we score all entities (minus anchors) so the link ranker actually has to
    # discriminate, instead of being handed only the known answers.
    t2_start = time.perf_counter()
    valid_answers = set(get_all_answers(qi_skeleton, query, adj_s_to_t))
    if not valid_answers:
        raise InvalidRequestError(
            "No entities in the knowledge graph satisfy this query"
        )

    anchor_entity_ids = {
        entities_skeleton[i] for i in query.query_anchors if entities_skeleton[i] != -1
    }

    rank_c = 0  # 1-based rank of the hit community (0 = no hit)
    c_err = 0  # sum of sizes of communities with better step-1 score than the hit
    community_size = 0  # size of the hit community
    candidates = []
    for rank_0indexed, cid in enumerate(community_order.tolist()):
        c_entities_tensor = (community_membership == cid).nonzero(as_tuple=True)[0]
        # One bulk conversion instead of a Python-level .item() per member.
        c_entities = c_entities_tensor.tolist()
        c_size = len(c_entities)
        if not valid_answers.isdisjoint(c_entities):
            rank_c = rank_0indexed + 1
            community_size = c_size
            candidates = [e for e in c_entities if e not in anchor_entity_ids]
            break
        c_err += c_size

    if not candidates:
        raise InvalidRequestError(
            "No entities in the knowledge graph satisfy this query"
        )

    n_candidates = len(candidates)
    candidates_tensor = pt.tensor(candidates, dtype=pt.long, device=device)

    # Build per-tree-node entity columns. Anchor / resolved-variable positions
    # repeat the same id across the batch; the answer and phantom-i positions
    # take each candidate.
    e_batch, x_batch, c_batch = [], [], []
    for i in range(num_tree_nodes):
        if entities_skeleton[i] == -1:
            entities_i = candidates_tensor
        else:
            entities_i = pt.full([n_candidates], entities_skeleton[i], dtype=pt.long, device=device)
        e_batch.append(entities_i)
        x_batch.append(one_hot(node_types_tensor[entities_i], num_node_types).float())
        c_batch.append(community_membership[entities_i])
    edge_attr_batch = _edge_attr_columns(qi_skeleton, num_tree_edges, n_candidates,
                                         num_relations, device)

    with pt.no_grad():
        batched_query = QueryData(query, e=e_batch, x=x_batch, c=c_batch, edge_attr=edge_attr_batch)
        score_chunks = []
        for chunk in batched_query.batch_split(SCORING_MINI_BATCH_SIZE):
            q_emb, a_emb = embedder(chunk)
            score_chunks.append(link_ranker(q_emb, a_emb).view(-1))
        scores = pt.cat(score_chunks)

    k = min(top_k, n_candidates)
    top_scores, top_indices = scores.topk(k)
    step2_ms = (time.perf_counter() - t2_start) * 1000.0

    # rank = c_err + intra_community_rank (exact rank_samples formula)
    inv_nodes, _, _ = loader.dataset.get_inverted_name_maps()
    predictions = []
    for intra_community_rank, (idx, score) in enumerate(zip(top_indices.tolist(), top_scores.tolist()), 1):
        entity_id = candidates[idx]
        raw_name = str(inv_nodes.get(entity_id, entity_id))
        predictions.append({
            "intra_community_rank": intra_community_rank,
            "rank": c_err + intra_community_rank,
            "entity_id": entity_id,
            "entity_name": clean_entity_name(raw_name, dataset_id),
            "score": round(float(score), 4),
            "is_valid_answer": entity_id in valid_answers,
        })

    total_ms = step1_ms + step2_ms
    # Speedup accounts for the full search cost: K community scores + skipped communities + hit community
    speedup = num_nodes / (num_communities + c_err + community_size)

    return _build_response(dataset_id, algorithm, query_structure_id, anchors,
                           relations_map, qs_mapping, predictions, total_ms,
                           step1_ms, step2_ms, speedup, loader,
                           rank_c=rank_c, c_err=c_err)
def _resolve_variables(qi_skeleton, query, entities_skeleton, qs_mapping,
                       user_variables, adj_s_to_t):
    """Fill in unspecified variable entities via get_node_cut_cache.

    Updates qi_skeleton and entities_skeleton in-place; returns None.
    Runs up to 3 passes to handle 3p which has 2 variables.

    qi_skeleton       — query-tree copy whose vertex attribute "e" holds entity ids.
    query             — Query object (provides query_tree and query_answer).
    entities_skeleton — list of entity ids per tree node; -1 marks unresolved.
    qs_mapping        — mapping with a "nodes" dict {api_node_id: tree_node_idx}.
    user_variables    — {api_node_id: entity_id} chosen by the caller; these are
                        used as given (no validation against the graph here).
    adj_s_to_t        — nested mapping entity -> relation -> iterable of entities.

    Raises InvalidRequestError when a variable that the caller did not pin has
    no candidate entity for the given anchors and relations.
    """
    # Reverse lookup: tree node index -> API node id.
    node_api_ids = {v: k for k, v in qs_mapping["nodes"].items()}

    def _run_pass():
        """Resolve every still-open node offered by the node-cut cache once."""
        cache = get_node_cut_cache(qi_skeleton, query, adj_s_to_t)
        resolved_any = False
        for tree_node_idx, candidates in cache.items():
            if tree_node_idx not in node_api_ids:
                continue
            api_id = node_api_ids[tree_node_idx]
            if entities_skeleton[tree_node_idx] != -1:
                continue  # anchor or already-resolved variable
            if api_id in user_variables:
                entity = int(user_variables[api_id])
            elif not candidates:
                raise InvalidRequestError(
                    f"No valid entity exists at variable '{api_id}' "
                    "for the given anchors and relations"
                )
            else:
                entity = int(np.random.choice(candidates))
            entities_skeleton[tree_node_idx] = entity
            qi_skeleton.vs[tree_node_idx]["e"] = entity
            resolved_any = True
        return resolved_any

    for _ in range(3):
        if not _run_pass():
            break

    # Post-processing: propagate resolved entities to nodes that get_node_cut_cache
    # can't reach (they lie between the cut and the anchors or answer in the tree).
    # Two cases:
    # - Phantom "i" nodes (e.g. ip nodes 1,3): inherit their resolved ancestor's entity
    # - Non-cut "p" variable nodes (e.g. 3p v1): resolve via adj_s_to_t[child_entity][rel]
    changed = True
    while changed:
        changed = False
        for i in range(query.query_tree.vcount()):
            if i == query.query_answer or entities_skeleton[i] != -1:
                continue
            in_edges = query.query_tree.es[query.query_tree.incident(i, mode="in")]
            if not in_edges:
                continue
            in_edge = in_edges[0]
            if in_edge["r"] == "i":
                # Phantom intersection node: copy the resolved parent (intersection var)
                parent = in_edge.source
                if entities_skeleton[parent] != -1:
                    entities_skeleton[i] = entities_skeleton[parent]
                    qi_skeleton.vs[i]["e"] = entities_skeleton[i]
                    changed = True
            elif "p" in in_edge["r"]:
                # A variable the caller pinned explicitly wins over random
                # sampling, matching what _run_pass does for cut nodes;
                # previously it was silently replaced by a random neighbour.
                pinned_id = node_api_ids.get(i)
                if pinned_id is not None and pinned_id in user_variables:
                    entities_skeleton[i] = int(user_variables[pinned_id])
                    qi_skeleton.vs[i]["e"] = entities_skeleton[i]
                    changed = True
                    continue
                # Projection variable between cut and anchors: resolve from child.
                # Use qi_skeleton.es to get the concrete "p{rel_id}" label (not the
                # generic "p" from query.query_tree that would cause int("") errors).
                out_edges = qi_skeleton.es[qi_skeleton.incident(i, mode="out")]
                if not out_edges:
                    continue
                child = out_edges[0].target
                if entities_skeleton[child] == -1:
                    continue  # child not resolved yet; retry on a later sweep
                rel_id = query_edge_r_to_int(out_edges[0]["r"])
                candidates = list(adj_s_to_t.get(entities_skeleton[child], {}).get(rel_id, []))
                if not candidates:
                    api_id = node_api_ids.get(i, str(i))
                    raise InvalidRequestError(
                        f"No valid entity exists at variable '{api_id}' "
                        "for the given anchors and relations"
                    )
                entities_skeleton[i] = int(np.random.choice(candidates))
                qi_skeleton.vs[i]["e"] = entities_skeleton[i]
                changed = True
def _build_response(dataset_id, algorithm, query_structure_id, anchors, relations_map,
                    qs_mapping, predictions, total_ms, step1_ms, step2_ms, speedup, loader,
                    rank_c=1, c_err=0):
    """Assemble the CoinsPredictResponse dict.

    Produces a human-readable description of the query (one
    "source --[relation]--> target" fragment per template edge, joined by
    ", ") and bundles it with the predictions and the timing figures.
    ``qs_mapping`` and ``c_err`` are accepted for call compatibility but are
    not read here.
    """
    entity_names, _, relation_names = loader.dataset.get_inverted_name_maps()
    template = next(qs for qs in QUERY_STRUCTURES if qs["id"] == query_structure_id)

    def _node_label(node):
        # Anchors show their cleaned entity name; other nodes get a placeholder.
        kind = node["type"]
        if kind == "anchor":
            entity = anchors[node["id"]]
            return clean_entity_name(str(entity_names.get(entity, entity)), dataset_id)
        if kind == "variable":
            return "(variable)"
        return "?"

    labels = {node["id"]: _node_label(node) for node in template["nodes"]}

    def _edge_text(edge):
        # Relation is looked up first, then the two endpoint labels.
        relation = relations_map[edge["id"]]
        pretty = clean_relation_name(str(relation_names.get(relation, relation)), dataset_id)
        return f"{labels[edge['source']]} --[{pretty}]--> {labels[edge['target']]}"

    # Comma-separated fragments give a clean multi-pattern display.
    query_description = ", ".join([_edge_text(edge) for edge in template["edges"]])

    timing = {
        "step1_ms": round(step1_ms, 1),
        "step1_label": "Community detection + localized embedding",
        "step2_ms": round(step2_ms, 1),
        "step2_label": "Link prediction",
        "total_ms": round(total_ms, 1),
        "rank_c": rank_c,
        "baseline_estimate_ms": round(total_ms * speedup, 1),
        "baseline_label": "Estimated baseline (without COINs)",
        "speedup": round(speedup, 2),
    }
    return {
        "dataset_id": dataset_id,
        "algorithm": algorithm,
        "query_structure": query_structure_id,
        "query_description": query_description,
        "predictions": predictions,
        "timing": timing,
    }