File size: 15,283 Bytes
"""COINs link prediction / query answering — inference logic.
This module is intentionally free of Django settings and registry state.
It receives a fully-prepared Experiment object and runs the inference pipeline.
"""
import time
import numpy as np
import torch as pt
from torch.nn.functional import one_hot
from api.exceptions import InvalidRequestError
from api.services.constants import QUERY_STRUCTURE_INTERNAL, QUERY_TREE_MAPPINGS, QUERY_STRUCTURES
from api.utils import clean_entity_name, clean_relation_name
# Research-code imports — available after settings.py adds the repos to sys.path.
from graph_completion.graphs.queries import Query, get_all_answers, get_node_cut_cache, query_edge_r_to_int
from graph_completion.graphs.preprocess import QueryData
# Upper bound on how many candidates go through the step-2 link ranker in a
# single forward pass. The hit community can hold thousands of nodes on
# Freebase, so scoring it in chunks keeps CPU memory bounded; chunking does
# not change the resulting scores.
SCORING_MINI_BATCH_SIZE: int = 512
def coins_predict_inner(experiment, dataset_id, algorithm, query_structure_id,
                        anchors, variables, relations_map, top_k):
    """Run a single COINs query and return the CoinsPredictResponse dict.

    experiment         — fully-prepared Experiment (embedder + link_ranker + loader loaded).
    dataset_id         — dataset key; used for name cleaning and echoed in the response.
    algorithm          — algorithm label, echoed in the response.
    query_structure_id — key into QUERY_TREE_MAPPINGS / QUERY_STRUCTURE_INTERNAL.
    anchors            — {api_node_id: entity_id} for all anchor nodes.
    variables          — {api_node_id: entity_id} for variable nodes (may be empty).
    relations_map      — {api_edge_id: relation_id} for all edges.
    top_k              — number of top predictions to return (documented max 10;
                         not enforced here, only clipped to the candidate count).

    Step 1 scores every community as the possible home of the answer.
    Step 2 picks the first community, in step-1 order, that contains a
    KG-valid answer. It then scores each non-anchor entity in that
    community with the link ranker.

    Raises InvalidRequestError when a variable cannot be resolved or no
    entity in the knowledge graph satisfies the query.
    """
    embedder = experiment.embedder
    link_ranker = experiment.link_ranker
    loader = experiment.loader
    device = next(embedder.parameters()).device
    qs_mapping = QUERY_TREE_MAPPINGS[query_structure_id]

    num_nodes = loader.num_nodes
    num_relations = loader.num_relations
    num_node_types = loader.num_node_types
    num_communities = loader.num_communities
    community_membership = embedder.community_membership  # PT tensor [num_nodes], long
    node_types_tensor = embedder.node_types  # PT tensor [num_nodes], long
    # Full-KG adjacency (train + val + test) — graph_indexes cover train only,
    # which would miss answers in val/test splits.
    adj_s_to_t = experiment.full_adj_s_to_t

    # ---- Build Query tree ----
    query = Query(QUERY_STRUCTURE_INTERNAL[query_structure_id])
    query.build_query_tree()
    query.get_node_cut()
    num_tree_nodes = query.query_tree.vcount()
    num_tree_edges = query.query_tree.ecount()

    # ---- Build query_instance_mapped skeleton ----
    qi_skeleton = query.query_tree.copy()
    entities_skeleton = [-1] * num_tree_nodes  # -1 = unresolved; avoids confusion with entity 0
    for api_id, entity_id in anchors.items():
        entities_skeleton[qs_mapping["nodes"][api_id]] = int(entity_id)
    qi_skeleton.vs["e"] = entities_skeleton
    for api_id, rel_id in relations_map.items():
        qi_skeleton.es[qs_mapping["edges"][api_id]]["r"] = f"p{rel_id}"
    # Intersection edges keep the "i" label from build_query_tree.

    # ---- Resolve variable entities ----
    _resolve_variables(qi_skeleton, query, entities_skeleton, qs_mapping,
                       variables, adj_s_to_t)
    # Edge labels are not modified after this point: "p{rel_id}" or "i".
    edge_labels = [qi_skeleton.es[j]["r"] for j in range(num_tree_edges)]

    # ---- Step 1: Community scoring ----
    t1_start = time.perf_counter()
    all_communities = pt.arange(num_communities, dtype=pt.long, device=device)
    e_c = []
    for i, entity in enumerate(entities_skeleton):
        if i == query.query_answer or entity == -1:
            # The answer, and any position still unresolved (e.g. phantom "i"
            # nodes whose parent is the answer), ranges over every community.
            # These are exactly the positions step 2 fills with candidates.
            # Indexing community_membership with -1 would silently read the
            # last node's community instead.
            e_c.append(all_communities)
        else:
            com = int(community_membership[entity].item())
            e_c.append(pt.full([num_communities], com, dtype=pt.long, device=device))
    edge_attr_c = _relation_one_hot_columns(edge_labels, num_communities, num_relations, device)
    with pt.no_grad():
        community_query = QueryData(query, e_c=e_c, edge_attr_c=edge_attr_c)
        c_q_emb, c_a_emb = embedder.embed_communities(community_query)
        # Flatten as step 2 does. argsort sorts along the last dim, so a
        # [K, 1] column would otherwise produce an all-zero ordering.
        community_scores = link_ranker(c_q_emb, c_a_emb, for_communities=True).view(-1)
    # Sort communities by descending score for the search loop
    community_order = community_scores.argsort(descending=True)  # [K]
    step1_ms = (time.perf_counter() - t1_start) * 1000.0

    # ---- Step 2: Score every entity in the hit community ----
    # Mirrors rank_samples (experiments.py:782-786). The hit community is the
    # first one in step-1 order that contains *any* KG-valid answer. Within it
    # we score all entities (minus anchors), so the link ranker actually has
    # to discriminate instead of being handed only the known answers.
    t2_start = time.perf_counter()
    valid_answers = set(get_all_answers(qi_skeleton, query, adj_s_to_t))
    if not valid_answers:
        raise InvalidRequestError(
            "No entities in the knowledge graph satisfy this query"
        )
    anchor_entity_ids = {
        entities_skeleton[i] for i in query.query_anchors if entities_skeleton[i] != -1
    }
    # rank_c: 1-based rank of the hit community (0 = no hit).
    # c_err: summed size of communities with a better step-1 score than the hit.
    rank_c, c_err, community_size, candidates = _locate_hit_community(
        community_order, community_membership, valid_answers, anchor_entity_ids
    )
    if not candidates:
        raise InvalidRequestError(
            "No entities in the knowledge graph satisfy this query"
        )
    n_candidates = len(candidates)
    candidates_tensor = pt.tensor(candidates, dtype=pt.long, device=device)

    # Build per-tree-node entity columns. Anchor / resolved-variable positions
    # repeat the same id across the batch; the answer and phantom-i positions
    # take each candidate.
    e_batch, x_batch, c_batch = [], [], []
    for entity in entities_skeleton:
        if entity == -1:
            entities_i = candidates_tensor
        else:
            entities_i = pt.full([n_candidates], entity, dtype=pt.long, device=device)
        e_batch.append(entities_i)
        x_batch.append(one_hot(node_types_tensor[entities_i], num_node_types).float())
        c_batch.append(community_membership[entities_i])
    edge_attr_batch = _relation_one_hot_columns(edge_labels, n_candidates, num_relations, device)
    with pt.no_grad():
        batched_query = QueryData(query, e=e_batch, x=x_batch, c=c_batch, edge_attr=edge_attr_batch)
        score_chunks = []
        for chunk in batched_query.batch_split(SCORING_MINI_BATCH_SIZE):
            q_emb, a_emb = embedder(chunk)
            score_chunks.append(link_ranker(q_emb, a_emb).view(-1))
        scores = pt.cat(score_chunks)
    k = min(top_k, n_candidates)
    top_scores, top_indices = scores.topk(k)
    step2_ms = (time.perf_counter() - t2_start) * 1000.0

    # rank = c_err + intra_community_rank (exact rank_samples formula)
    inv_nodes, _, _ = loader.dataset.get_inverted_name_maps()
    predictions = []
    for intra_community_rank, (idx, score) in enumerate(
        zip(top_indices.tolist(), top_scores.tolist()), 1
    ):
        entity_id = candidates[idx]
        raw_name = str(inv_nodes.get(entity_id, entity_id))
        predictions.append({
            "intra_community_rank": intra_community_rank,
            "rank": c_err + intra_community_rank,
            "entity_id": entity_id,
            "entity_name": clean_entity_name(raw_name, dataset_id),
            "score": round(float(score), 4),
            "is_valid_answer": entity_id in valid_answers,
        })

    total_ms = step1_ms + step2_ms
    # Speedup accounts for the full search cost:
    # K community scores + skipped communities + hit community.
    speedup = num_nodes / (num_communities + c_err + community_size)
    return _build_response(dataset_id, algorithm, query_structure_id, anchors,
                           relations_map, qs_mapping, predictions, total_ms,
                           step1_ms, step2_ms, speedup, loader,
                           rank_c=rank_c, c_err=c_err)


def _relation_one_hot_columns(edge_labels, batch_size, num_relations, device):
    """Build one [batch_size, num_relations + 1] one-hot float tensor per tree edge.

    Projection edges are labelled "p{rel_id}" and encode that relation id.
    Intersection edges encode the extra "no relation" slot at index
    num_relations.
    """
    columns = []
    for r_label in edge_labels:
        r_id = int(r_label[1:]) if "p" in r_label else num_relations
        ids = pt.full([batch_size], r_id, dtype=pt.long, device=device)
        columns.append(one_hot(ids, num_relations + 1).float())
    return columns


def _locate_hit_community(community_order, community_membership, valid_answers, anchor_entity_ids):
    """Find the first community, in step-1 order, holding any KG-valid answer.

    Returns (rank_c, c_err, community_size, candidates):
      rank_c         — 1-based rank of the hit community (0 when none hits);
      c_err          — summed size of the communities ranked before it;
      community_size — size of the hit community (0 when none hits);
      candidates     — its member entity ids minus anchor_entity_ids ([] when none hits).
    """
    c_err = 0
    for rank_0indexed, cid in enumerate(community_order.tolist()):
        members = (community_membership == cid).nonzero(as_tuple=True)[0].tolist()
        if any(e in valid_answers for e in members):
            candidates = [e for e in members if e not in anchor_entity_ids]
            return rank_0indexed + 1, c_err, len(members), candidates
        c_err += len(members)
    return 0, c_err, 0, []
def _resolve_variables(qi_skeleton, query, entities_skeleton, qs_mapping,
                       user_variables, adj_s_to_t):
    """Fill in variable entities that the anchors alone do not fix.

    Updates qi_skeleton (vertex attribute "e") and entities_skeleton in-place.
    In entities_skeleton, -1 marks an unresolved tree node.

    1. Up to 3 passes over get_node_cut_cache, to handle 3p which has 2
       variables. Each returned tree node that maps to an API node and is
       still unresolved is handled in one of three ways:
       - it takes the user-supplied entity if one is given;
       - otherwise it takes a random entity from the cache's candidates
         (np.random, global RNG);
       - it raises if there are no candidates.
    2. A propagation loop for nodes the cache cannot reach, which lie between
       the cut and the anchors or answer:
       - phantom "i" nodes copy their resolved parent's entity;
       - non-cut "p" variable nodes (e.g. 3p v1) take the user-supplied
         entity if one is given. Otherwise they take a random target from
         adj_s_to_t[child_entity][rel_id] once their child is resolved.

    Raises InvalidRequestError when a variable has no valid entity.
    """
    node_api_ids = {v: k for k, v in qs_mapping["nodes"].items()}

    def _assign(tree_node_idx, entity):
        # Keep the list and the igraph vertex attribute in sync.
        entities_skeleton[tree_node_idx] = entity
        qi_skeleton.vs[tree_node_idx]["e"] = entity

    def _no_entity_error(api_id):
        return InvalidRequestError(
            f"No valid entity exists at variable '{api_id}' "
            "for the given anchors and relations"
        )

    def _run_pass():
        cache = get_node_cut_cache(qi_skeleton, query, adj_s_to_t)
        resolved_any = False
        for tree_node_idx, candidates in cache.items():
            if tree_node_idx not in node_api_ids:
                continue
            if entities_skeleton[tree_node_idx] != -1:
                continue  # anchor or already-resolved variable
            api_id = node_api_ids[tree_node_idx]
            if api_id in user_variables:
                entity = int(user_variables[api_id])
            elif not candidates:
                raise _no_entity_error(api_id)
            else:
                entity = int(np.random.choice(candidates))
            _assign(tree_node_idx, entity)
            resolved_any = True
        return resolved_any

    for _ in range(3):
        if not _run_pass():
            break

    # Post-processing: propagate resolved entities to nodes that
    # get_node_cut_cache can't reach (they lie between the cut and the anchors
    # or answer in the tree). Repeat until nothing changes.
    changed = True
    while changed:
        changed = False
        for i in range(query.query_tree.vcount()):
            if i == query.query_answer or entities_skeleton[i] != -1:
                continue
            in_edges = query.query_tree.es[query.query_tree.incident(i, mode="in")]
            if not in_edges:
                continue
            in_edge = in_edges[0]
            if in_edge["r"] == "i":
                # Phantom intersection node: copy the resolved parent (intersection var).
                parent = in_edge.source
                if entities_skeleton[parent] != -1:
                    _assign(i, entities_skeleton[parent])
                    changed = True
            elif "p" in in_edge["r"]:
                api_id = node_api_ids.get(i)
                if api_id is not None and api_id in user_variables:
                    # Honour an explicit user choice, as the cut-node pass does;
                    # previously this node was always filled at random.
                    _assign(i, int(user_variables[api_id]))
                    changed = True
                    continue
                # Projection variable between cut and anchors: resolve from child.
                # Use qi_skeleton.es to get the concrete "p{rel_id}" label (not the
                # generic "p" from query.query_tree that would cause int("") errors).
                out_edges = qi_skeleton.es[qi_skeleton.incident(i, mode="out")]
                if not out_edges:
                    continue
                child = out_edges[0].target
                if entities_skeleton[child] == -1:
                    continue  # retried on a later sweep once the child is resolved
                rel_id = query_edge_r_to_int(out_edges[0]["r"])
                candidates = list(adj_s_to_t.get(entities_skeleton[child], {}).get(rel_id, []))
                if not candidates:
                    raise _no_entity_error(node_api_ids.get(i, str(i)))
                _assign(i, int(np.random.choice(candidates)))
                changed = True
def _build_response(dataset_id, algorithm, query_structure_id, anchors, relations_map,
                    qs_mapping, predictions, total_ms, step1_ms, step2_ms, speedup, loader,
                    rank_c=1, c_err=0):
    """Assemble the CoinsPredictResponse dict.

    The human-readable query description is rendered from the
    QUERY_STRUCTURES entry whose "id" equals query_structure_id.
    - Anchor nodes show their cleaned entity name.
    - Variable nodes show "(variable)"; any other node shows "?".
    - Each edge becomes "source --[relation]--> target", and edges are
      joined with ", ".

    A missing template raises StopIteration (bare next()). qs_mapping and
    c_err are accepted but not used here.
    """
    inv_nodes, _, inv_relations = loader.dataset.get_inverted_name_maps()
    template = next(qs for qs in QUERY_STRUCTURES if qs["id"] == query_structure_id)

    def _label(node):
        kind = node["type"]
        if kind == "anchor":
            entity = anchors[node["id"]]
            return clean_entity_name(str(inv_nodes.get(entity, entity)), dataset_id)
        if kind == "variable":
            return "(variable)"
        return "?"

    labels = {node["id"]: _label(node) for node in template["nodes"]}

    def _describe(edge):
        relation_id = relations_map[edge["id"]]
        relation = clean_relation_name(str(inv_relations.get(relation_id, relation_id)), dataset_id)
        return f"{labels[edge['source']]} --[{relation}]--> {labels[edge['target']]}"

    # Comma-separated so multi-pattern queries read cleanly.
    description = ", ".join(_describe(edge) for edge in template["edges"])

    timing = {
        "step1_ms": round(step1_ms, 1),
        "step1_label": "Community detection + localized embedding",
        "step2_ms": round(step2_ms, 1),
        "step2_label": "Link prediction",
        "total_ms": round(total_ms, 1),
        "rank_c": rank_c,
        "baseline_estimate_ms": round(total_ms * speedup, 1),
        "baseline_label": "Estimated baseline (without COINs)",
        "speedup": round(speedup, 2),
    }
    return {
        "dataset_id": dataset_id,
        "algorithm": algorithm,
        "query_structure": query_structure_id,
        "query_description": description,
        "predictions": predictions,
        "timing": timing,
    }
|