"""COINs link prediction / query answering — inference logic.

This module is intentionally free of Django settings and registry state.
It receives a fully-prepared Experiment object and runs the inference pipeline.
"""

import time

import numpy as np
import torch as pt
from torch.nn.functional import one_hot

from api.exceptions import InvalidRequestError
from api.services.constants import QUERY_STRUCTURE_INTERNAL, QUERY_TREE_MAPPINGS, QUERY_STRUCTURES
from api.utils import clean_entity_name, clean_relation_name

# Research-code imports — available after settings.py adds the repos to sys.path.
from graph_completion.graphs.queries import Query, get_all_answers, get_node_cut_cache, query_edge_r_to_int
from graph_completion.graphs.preprocess import QueryData

# Step-2 mini-batch cap: the link-ranker runs a full forward pass per candidate,
# and the hit community can hold thousands of nodes on Freebase. Splitting keeps
# CPU memory bounded without changing results.
SCORING_MINI_BATCH_SIZE = 512

# Marker stored in the entity skeleton for tree nodes without a concrete entity
# (the answer, phantom intersection copies of it, not-yet-resolved variables).
# -1 rather than 0 because 0 is a valid entity id.
_UNRESOLVED = -1


def coins_predict_inner(experiment, dataset_id, algorithm, query_structure_id,
                        anchors, variables, relations_map, top_k):
    """Run a single COINs query and return the CoinsPredictResponse dict.

    experiment         — fully-prepared Experiment (embedder + link_ranker + loader loaded).
    dataset_id         — dataset key, forwarded to the entity/relation name cleaners.
    algorithm          — algorithm label echoed back in the response.
    query_structure_id — key into QUERY_TREE_MAPPINGS / QUERY_STRUCTURE_INTERNAL.
    anchors            — {api_node_id: entity_id} for all anchor nodes.
    variables          — {api_node_id: entity_id} for variable nodes (may be empty);
                         variables not listed here are filled in by _resolve_variables.
    relations_map      — {api_edge_id: relation_id} for all edges.
    top_k              — number of top predictions to return (max 10; that limit is
                         not enforced here). Truncated to the hit-community size;
                         a negative value yields no predictions.

    Raises InvalidRequestError when no KG entity satisfies the query or a
    variable cannot be resolved.
    """
    embedder = experiment.embedder
    link_ranker = experiment.link_ranker
    loader = experiment.loader
    device = next(embedder.parameters()).device

    qs_mapping = QUERY_TREE_MAPPINGS[query_structure_id]
    num_nodes = loader.num_nodes
    num_relations = loader.num_relations
    num_node_types = loader.num_node_types
    num_communities = loader.num_communities

    community_membership = embedder.community_membership  # PT tensor [num_nodes], long
    node_types_tensor = embedder.node_types  # PT tensor [num_nodes], long

    # Full-KG adjacency (train + val + test) — graph_indexes cover train only,
    # which would miss answers in val/test splits.
    adj_s_to_t = experiment.full_adj_s_to_t

    # ---- Build Query tree ----
    query = Query(QUERY_STRUCTURE_INTERNAL[query_structure_id])
    query.build_query_tree()
    query.get_node_cut()
    num_tree_nodes = query.query_tree.vcount()
    num_tree_edges = query.query_tree.ecount()

    # ---- Build query_instance_mapped skeleton ----
    qi_skeleton = query.query_tree.copy()
    entities_skeleton = [_UNRESOLVED] * num_tree_nodes
    for api_id, entity_id in anchors.items():
        entities_skeleton[qs_mapping["nodes"][api_id]] = int(entity_id)
    qi_skeleton.vs["e"] = entities_skeleton
    for api_id, rel_id in relations_map.items():
        qi_skeleton.es[qs_mapping["edges"][api_id]]["r"] = f"p{rel_id}"
    # intersection edges keep "i" from build_query_tree

    # ---- Resolve variable entities ----
    _resolve_variables(qi_skeleton, query, entities_skeleton, qs_mapping, variables, adj_s_to_t)

    edge_labels = [qi_skeleton.es[j]["r"] for j in range(num_tree_edges)]

    # ---- Step 1: Community scoring ----
    t1_start = time.perf_counter()
    all_communities = pt.arange(num_communities, dtype=pt.long, device=device)
    e_c = []
    for i in range(num_tree_nodes):
        if i == query.query_answer or entities_skeleton[i] == _UNRESOLVED:
            # The answer and its phantom intersection copies: batch row k tries
            # community k. Indexing community_membership with -1 here would
            # silently pick the last node's community instead.
            e_c.append(all_communities)
        else:
            com = int(community_membership[entities_skeleton[i]].item())
            e_c.append(pt.full([num_communities], com, dtype=pt.long, device=device))
    edge_attr_c = [_edge_attr_one_hot(label, num_communities, num_relations, device) for label in edge_labels]

    with pt.no_grad():
        community_query = QueryData(query, e_c=e_c, edge_attr_c=edge_attr_c)
        c_q_emb, c_a_emb = embedder.embed_communities(community_query)
        community_scores = link_ranker(c_q_emb, c_a_emb, for_communities=True)

    # Sort communities by descending score for the search loop. .tolist() pulls
    # the result to the host, so on CUDA the timer below includes the GPU work.
    community_order = community_scores.argsort(descending=True).reshape(-1).tolist()  # [K]
    step1_ms = (time.perf_counter() - t1_start) * 1000.0

    # ---- Step 2: Score every entity in the hit community ----
    # Mirrors rank_samples (experiments.py:782-786): the hit community is the
    # first one in step-1 order that contains *any* KG-valid answer; within it
    # we score all entities (minus anchors) so the link ranker actually has to
    # discriminate, instead of being handed only the known answers.
    t2_start = time.perf_counter()
    valid_answers = set(get_all_answers(qi_skeleton, query, adj_s_to_t))
    if not valid_answers:
        raise InvalidRequestError(
            "No entities in the knowledge graph satisfy this query"
        )
    anchor_entity_ids = {
        entities_skeleton[i] for i in query.query_anchors if entities_skeleton[i] != _UNRESOLVED
    }

    rank_c, c_err, community_size, candidates = _find_hit_community(
        community_order, community_membership, valid_answers, anchor_entity_ids
    )
    if not candidates:
        raise InvalidRequestError(
            "No entities in the knowledge graph satisfy this query"
        )
    n_candidates = len(candidates)
    candidates_tensor = pt.tensor(candidates, dtype=pt.long, device=device)

    # Build per-tree-node entity columns. Anchor / resolved-variable positions
    # repeat the same id across the batch; the answer and phantom-i positions
    # take each candidate.
    e_batch, x_batch, c_batch = [], [], []
    for entity in entities_skeleton:
        if entity == _UNRESOLVED:
            entities_i = candidates_tensor
        else:
            entities_i = pt.full([n_candidates], entity, dtype=pt.long, device=device)
        e_batch.append(entities_i)
        x_batch.append(one_hot(node_types_tensor[entities_i], num_node_types).float())
        c_batch.append(community_membership[entities_i])
    edge_attr_batch = [_edge_attr_one_hot(label, n_candidates, num_relations, device) for label in edge_labels]

    with pt.no_grad():
        batched_query = QueryData(query, e=e_batch, x=x_batch, c=c_batch, edge_attr=edge_attr_batch)
        score_chunks = []
        for chunk in batched_query.batch_split(SCORING_MINI_BATCH_SIZE):
            q_emb, a_emb = embedder(chunk)
            score_chunks.append(link_ranker(q_emb, a_emb).view(-1))
        scores = pt.cat(score_chunks)

    k = max(0, min(top_k, n_candidates))
    top_scores_t, top_indices_t = scores.topk(k)
    # Materialise on the host before stopping the timer (forces a CUDA sync).
    top_scores, top_indices = top_scores_t.tolist(), top_indices_t.tolist()
    step2_ms = (time.perf_counter() - t2_start) * 1000.0

    # rank = c_err + intra_community_rank (exact rank_samples formula)
    inv_nodes, _, _ = loader.dataset.get_inverted_name_maps()
    predictions = []
    for intra_community_rank, (idx, score) in enumerate(zip(top_indices, top_scores), 1):
        entity_id = candidates[idx]
        raw_name = str(inv_nodes.get(entity_id, entity_id))
        predictions.append({
            "intra_community_rank": intra_community_rank,
            "rank": c_err + intra_community_rank,
            "entity_id": entity_id,
            "entity_name": clean_entity_name(raw_name, dataset_id),
            "score": round(float(score), 4),
            "is_valid_answer": entity_id in valid_answers,
        })

    total_ms = step1_ms + step2_ms
    # Speedup accounts for the full search cost: K community scores + skipped communities + hit community
    speedup = num_nodes / (num_communities + c_err + community_size)

    return _build_response(dataset_id, algorithm, query_structure_id, anchors, relations_map,
                           qs_mapping, predictions, total_ms, step1_ms, step2_ms, speedup, loader,
                           rank_c=rank_c, c_err=c_err)


def _edge_attr_one_hot(r_label, batch_size, num_relations, device):
    """Return a float one-hot edge-attribute block of shape [batch_size, num_relations + 1].

    A projection label "p{rel_id}" selects slot rel_id. Any label without a "p"
    (the intersection label "i") selects the extra "no relation" slot, index
    num_relations.
    """
    slot = int(r_label[1:]) if "p" in r_label else num_relations
    return one_hot(pt.full([batch_size], slot, dtype=pt.long, device=device), num_relations + 1).float()


def _find_hit_community(community_order, community_membership, valid_answers, excluded_entities):
    """Find the first community, in step-1 score order, that holds any valid answer.

    Returns (rank_c, c_err, community_size, candidates):
      rank_c         — 1-based rank of the hit community (0 when there is no hit),
      c_err          — summed size of the better-scored communities skipped before the hit,
      community_size — size of the hit community (0 when there is no hit),
      candidates     — hit-community entity ids not in excluded_entities, in index order
                       (empty when there is no hit).
    """
    c_err = 0
    for rank_0indexed, cid in enumerate(community_order):
        c_entities = (community_membership == cid).nonzero(as_tuple=True)[0].tolist()
        if any(e in valid_answers for e in c_entities):
            candidates = [e for e in c_entities if e not in excluded_entities]
            return rank_0indexed + 1, c_err, len(c_entities), candidates
        c_err += len(c_entities)
    return 0, c_err, 0, []


def _resolve_variables(qi_skeleton, query, entities_skeleton, qs_mapping, user_variables, adj_s_to_t):
    """Fill in unspecified variable entities via get_node_cut_cache.

    Updates qi_skeleton and entities_skeleton in-place. Runs up to 3 passes to
    handle 3p, which has 2 variables.

    User-supplied variables are taken as given; they are not checked against
    the candidate set. Other variables are drawn with np.random.choice (global
    NumPy RNG) from their candidates.

    Raises InvalidRequestError when a variable has no KG-consistent candidate.
    """
    node_api_ids = {v: k for k, v in qs_mapping["nodes"].items()}

    def _assign(tree_node_idx, entity):
        # Keep the list and the igraph vertex attribute in lock-step.
        entities_skeleton[tree_node_idx] = entity
        qi_skeleton.vs[tree_node_idx]["e"] = entity

    def _run_pass():
        cache = get_node_cut_cache(qi_skeleton, query, adj_s_to_t)
        resolved_any = False
        for tree_node_idx, candidates in cache.items():
            if tree_node_idx not in node_api_ids:
                continue
            if entities_skeleton[tree_node_idx] != _UNRESOLVED:
                continue  # anchor or already-resolved variable
            api_id = node_api_ids[tree_node_idx]
            if api_id in user_variables:
                entity = int(user_variables[api_id])
            else:
                # list() so sets / arrays work with both the emptiness test and
                # np.random.choice (which rejects sets).
                pool = list(candidates)
                if not pool:
                    raise InvalidRequestError(
                        f"No valid entity exists at variable '{api_id}' "
                        "for the given anchors and relations"
                    )
                entity = int(np.random.choice(pool))
            _assign(tree_node_idx, entity)
            resolved_any = True
        return resolved_any

    for _ in range(3):
        if not _run_pass():
            break

    # Post-processing: propagate resolved entities to nodes that get_node_cut_cache
    # can't reach (they lie between the cut and the anchors or answer in the tree).
    # Two cases:
    #   - Phantom "i" nodes (e.g. ip nodes 1,3): inherit their resolved ancestor's entity
    #   - Non-cut "p" variable nodes (e.g. 3p v1): resolve via adj_s_to_t[child_entity][rel]
    changed = True
    while changed:
        changed = False
        for i in range(query.query_tree.vcount()):
            if i == query.query_answer or entities_skeleton[i] != _UNRESOLVED:
                continue
            in_edges = query.query_tree.es[query.query_tree.incident(i, mode="in")]
            if not in_edges:
                continue
            in_edge = in_edges[0]
            if in_edge["r"] == "i":
                # Phantom intersection node: copy the resolved parent (intersection var)
                parent = in_edge.source
                if entities_skeleton[parent] != _UNRESOLVED:
                    _assign(i, entities_skeleton[parent])
                    changed = True
            elif "p" in in_edge["r"]:
                # Projection variable between cut and anchors: resolve from child.
                # Use qi_skeleton.es to get the concrete "p{rel_id}" label (not the
                # generic "p" from query.query_tree that would cause int("") errors).
                out_edges = qi_skeleton.es[qi_skeleton.incident(i, mode="out")]
                if not out_edges:
                    continue
                child = out_edges[0].target
                if entities_skeleton[child] == _UNRESOLVED:
                    continue
                rel_id = query_edge_r_to_int(out_edges[0]["r"])
                candidates = list(adj_s_to_t.get(entities_skeleton[child], {}).get(rel_id, []))
                if not candidates:
                    api_id = node_api_ids.get(i, str(i))
                    raise InvalidRequestError(
                        f"No valid entity exists at variable '{api_id}' "
                        "for the given anchors and relations"
                    )
                _assign(i, int(np.random.choice(candidates)))
                changed = True


def _build_response(dataset_id, algorithm, query_structure_id, anchors, relations_map, qs_mapping,
                    predictions, total_ms, step1_ms, step2_ms, speedup, loader, rank_c=1, c_err=0):
    """Assemble the CoinsPredictResponse dict.

    The query description uses human-readable names looked up from the loader's
    inverted name maps. qs_mapping and c_err are accepted for call-site
    compatibility but are not currently part of the response.
    """
    inv_nodes, _, inv_relations = loader.dataset.get_inverted_name_maps()
    qs_template = next(qs for qs in QUERY_STRUCTURES if qs["id"] == query_structure_id)

    # Build a name map for each node in the template
    node_name = {}
    for node in qs_template["nodes"]:
        if node["type"] == "anchor":
            # int() mirrors the normalisation coins_predict_inner applies to
            # anchor ids, so the name lookup uses the same key type.
            entity_id = int(anchors[node["id"]])
            raw = str(inv_nodes.get(entity_id, entity_id))
            node_name[node["id"]] = clean_entity_name(raw, dataset_id)
        elif node["type"] == "variable":
            node_name[node["id"]] = "(variable)"
        else:
            node_name[node["id"]] = "?"

    # Describe each edge; comma-separate for a clean multi-pattern display
    edge_parts = []
    for edge in qs_template["edges"]:
        # coins_predict_inner already parses every relation id as an int
        # (via the "p{rel_id}" label), so int() is safe and matches its key type.
        rel_id = int(relations_map[edge["id"]])
        rel = clean_relation_name(str(inv_relations.get(rel_id, rel_id)), dataset_id)
        edge_parts.append(f"{node_name[edge['source']]} --[{rel}]--> {node_name[edge['target']]}")
    query_description = ", ".join(edge_parts)

    return {
        "dataset_id": dataset_id,
        "algorithm": algorithm,
        "query_structure": query_structure_id,
        "query_description": query_description,
        "predictions": predictions,
        "timing": {
            "step1_ms": round(step1_ms, 1),
            "step1_label": "Community detection + localized embedding",
            "step2_ms": round(step2_ms, 1),
            "step2_label": "Link prediction",
            "total_ms": round(total_ms, 1),
            "rank_c": rank_c,
            "baseline_estimate_ms": round(total_ms * speedup, 1),
            "baseline_label": "Estimated baseline (without COINs)",
            "speedup": round(speedup, 2),
        },
    }