"""COINs link prediction / query answering — inference logic.

This module is intentionally free of Django settings and registry state.
It receives a fully-prepared Experiment object and runs the inference pipeline.
"""
| import time |
|
|
| import numpy as np |
| import torch as pt |
| from torch.nn.functional import one_hot |
|
|
| from api.exceptions import InvalidRequestError |
| from api.services.constants import QUERY_STRUCTURE_INTERNAL, QUERY_TREE_MAPPINGS, QUERY_STRUCTURES |
| from api.utils import clean_entity_name, clean_relation_name |
|
|
| |
| from graph_completion.graphs.queries import Query, get_all_answers, get_node_cut_cache, query_edge_r_to_int |
| from graph_completion.graphs.preprocess import QueryData |
|
|
| |
| |
| |
# Chunk size handed to ``QueryData.batch_split`` when the candidate entities
# are scored in step 2 of ``coins_predict_inner``.
SCORING_MINI_BATCH_SIZE = 512
|
|
|
|
def _relation_one_hot(edge_labels, batch_size, num_relations, device):
    """Return one one-hot relation tensor per query-tree edge label.

    edge_labels   — iterable of edge "r" labels taken from the query tree.
    batch_size    — number of rows in every produced tensor.
    num_relations — number of relations; the encoding has ``num_relations + 1`` classes.
    device        — torch device the tensors are created on.

    A label containing "p" is read as ``p<relation_id>``; every other label
    is encoded with the extra class index ``num_relations``.
    """
    edge_attrs = []
    for r_label in edge_labels:
        r_id = int(r_label[1:]) if "p" in r_label else num_relations
        index = pt.full([batch_size], r_id, dtype=pt.long, device=device)
        edge_attrs.append(one_hot(index, num_relations + 1).float())
    return edge_attrs


def coins_predict_inner(experiment, dataset_id, algorithm, query_structure_id,
                        anchors, variables, relations_map, top_k):
    """Run a single COINs query and return the CoinsPredictResponse dict.

    experiment    — fully-prepared Experiment (embedder + link_ranker + loader loaded).
    dataset_id    — dataset identifier, used for name cleaning and echoed in the response.
    algorithm     — algorithm identifier, echoed in the response.
    query_structure_id — key into QUERY_TREE_MAPPINGS / QUERY_STRUCTURE_INTERNAL.
    anchors       — {api_node_id: entity_id} for all anchor nodes.
    variables     — {api_node_id: entity_id} for variable nodes (may be empty).
    relations_map — {api_edge_id: relation_id} for all edges.
    top_k         — number of top predictions to return (max 10). This function
                    only caps it at the number of candidates; the limit of 10 is
                    presumably enforced by the caller.

    Returns the dict assembled by ``_build_response``.

    Raises InvalidRequestError when no entity satisfies the query, when the
    first community containing a valid answer has no non-anchor entity, or
    when ``_resolve_variables`` cannot resolve a variable node.
    """
    embedder = experiment.embedder
    link_ranker = experiment.link_ranker
    loader = experiment.loader
    device = next(embedder.parameters()).device

    qs_mapping = QUERY_TREE_MAPPINGS[query_structure_id]
    num_nodes = loader.num_nodes
    num_relations = loader.num_relations
    num_node_types = loader.num_node_types
    num_communities = loader.num_communities
    community_membership = embedder.community_membership
    node_types_tensor = embedder.node_types

    adj_s_to_t = experiment.full_adj_s_to_t

    query = Query(QUERY_STRUCTURE_INTERNAL[query_structure_id])
    query.build_query_tree()
    query.get_node_cut()

    num_tree_nodes = query.query_tree.vcount()
    num_tree_edges = query.query_tree.ecount()

    # Instantiate the query tree: -1 marks a node without an entity yet.
    qi_skeleton = query.query_tree.copy()
    entities_skeleton = [-1] * num_tree_nodes
    for api_id, entity_id in anchors.items():
        tree_idx = qs_mapping["nodes"][api_id]
        entities_skeleton[tree_idx] = int(entity_id)
    qi_skeleton.vs["e"] = entities_skeleton
    for api_id, rel_id in relations_map.items():
        edge_idx = qs_mapping["edges"][api_id]
        qi_skeleton.es[edge_idx]["r"] = f"p{rel_id}"

    # Fills entities_skeleton / qi_skeleton in place for variable nodes.
    _resolve_variables(qi_skeleton, query, entities_skeleton, qs_mapping,
                       variables, adj_s_to_t)

    # ---- Step 1: score every community as the answer community. ----
    t1_start = time.perf_counter()

    all_communities = pt.arange(num_communities, dtype=pt.long, device=device)
    e_c = []
    for i in range(num_tree_nodes):
        if i == query.query_answer:
            e_c.append(all_communities)
        else:
            com = int(community_membership[entities_skeleton[i]].item())
            e_c.append(pt.full([num_communities], com, dtype=pt.long, device=device))
    edge_attr_c = _relation_one_hot(
        [qi_skeleton.es[j]["r"] for j in range(num_tree_edges)],
        num_communities, num_relations, device,
    )

    with pt.no_grad():
        community_query = QueryData(query, e_c=e_c, edge_attr_c=edge_attr_c)
        c_q_emb, c_a_emb = embedder.embed_communities(community_query)
        community_scores = link_ranker(c_q_emb, c_a_emb, for_communities=True)

    community_order = community_scores.argsort(descending=True)
    step1_ms = (time.perf_counter() - t1_start) * 1000.0

    # ---- Step 2: score the entities of the first community holding a valid answer. ----
    t2_start = time.perf_counter()

    valid_answers = set(get_all_answers(qi_skeleton, query, adj_s_to_t))
    if not valid_answers:
        raise InvalidRequestError(
            "No entities in the knowledge graph satisfy this query"
        )

    anchor_entity_ids = {
        entities_skeleton[i] for i in query.query_anchors if entities_skeleton[i] != -1
    }

    rank_c = 0
    c_err = 0  # total size of the communities ranked above the chosen one
    community_size = 0
    candidates = []

    for rank_0indexed in range(num_communities):
        cid = int(community_order[rank_0indexed].item())
        # One bulk conversion instead of a per-element ``.item()`` call.
        c_entities = (community_membership == cid).nonzero(as_tuple=True)[0].tolist()
        c_size = len(c_entities)

        if any(e in valid_answers for e in c_entities):
            rank_c = rank_0indexed + 1
            community_size = c_size
            candidates = [e for e in c_entities if e not in anchor_entity_ids]
            break

        c_err += c_size

    if not candidates:
        raise InvalidRequestError(
            "No entities in the knowledge graph satisfy this query"
        )

    n_candidates = len(candidates)
    candidates_tensor = pt.tensor(candidates, dtype=pt.long, device=device)

    # One row per candidate: nodes still at -1 take the candidate itself,
    # every other node repeats its fixed entity.
    e_batch, x_batch, c_batch = [], [], []
    for i in range(num_tree_nodes):
        if entities_skeleton[i] == -1:
            entities_i = candidates_tensor
        else:
            entities_i = pt.full([n_candidates], entities_skeleton[i], dtype=pt.long, device=device)
        e_batch.append(entities_i)
        x_batch.append(one_hot(node_types_tensor[entities_i], num_node_types).float())
        c_batch.append(community_membership[entities_i])
    edge_attr_batch = _relation_one_hot(
        [qi_skeleton.es[j]["r"] for j in range(num_tree_edges)],
        n_candidates, num_relations, device,
    )

    with pt.no_grad():
        batched_query = QueryData(query, e=e_batch, x=x_batch, c=c_batch, edge_attr=edge_attr_batch)
        score_chunks = []
        for chunk in batched_query.batch_split(SCORING_MINI_BATCH_SIZE):
            q_emb, a_emb = embedder(chunk)
            score_chunks.append(link_ranker(q_emb, a_emb).view(-1))
        scores = pt.cat(score_chunks)

    k = min(top_k, n_candidates)
    top_scores, top_indices = scores.topk(k)
    step2_ms = (time.perf_counter() - t2_start) * 1000.0

    inv_nodes, _, _ = loader.dataset.get_inverted_name_maps()
    predictions = []
    for intra_community_rank, (idx, score) in enumerate(zip(top_indices.tolist(), top_scores.tolist()), 1):
        entity_id = candidates[idx]
        raw_name = str(inv_nodes.get(entity_id, entity_id))
        predictions.append({
            "intra_community_rank": intra_community_rank,
            # Overall rank counts every entity of the higher-ranked communities first.
            "rank": c_err + intra_community_rank,
            "entity_id": entity_id,
            "entity_name": clean_entity_name(raw_name, dataset_id),
            "score": round(float(score), 4),
            "is_valid_answer": entity_id in valid_answers,
        })

    total_ms = step1_ms + step2_ms

    # Ratio of all nodes to (communities + skipped entities + chosen community size).
    speedup = num_nodes / (num_communities + c_err + community_size)
    return _build_response(dataset_id, algorithm, query_structure_id, anchors,
                           relations_map, qs_mapping, predictions, total_ms,
                           step1_ms, step2_ms, speedup, loader,
                           rank_c=rank_c, c_err=c_err)
|
|
|
|
def _resolve_variables(qi_skeleton, query, entities_skeleton, qs_mapping,
                       user_variables, adj_s_to_t):
    """Assign an entity to every still-unset variable node of the query tree.

    ``qi_skeleton`` (vertex attribute "e") and ``entities_skeleton`` are
    updated in place; nothing is returned.

    Phase 1 runs at most three passes over ``get_node_cut_cache`` (3p has two
    variables). Each unset node present in the cache takes the user-supplied
    entity when one exists, otherwise a random one of the cached candidates.

    Phase 2 is a fallback that repeats until a sweep changes nothing: an unset
    non-answer node either copies its parent's entity (incoming "i" edge) or
    draws a random entity from ``adj_s_to_t`` keyed by its first child's
    entity and the relation of the edge leading to that child.

    Raises InvalidRequestError when a node has no candidate entity.
    """
    api_id_by_tree_idx = {tree_idx: api_id for api_id, tree_idx in qs_mapping["nodes"].items()}

    # Phase 1: node-cut cache passes; stop early once a pass resolves nothing.
    for _ in range(3):
        node_cut_cache = get_node_cut_cache(qi_skeleton, query, adj_s_to_t)
        progressed = False
        for node_idx, options in node_cut_cache.items():
            if node_idx not in api_id_by_tree_idx or entities_skeleton[node_idx] != -1:
                continue
            api_id = api_id_by_tree_idx[node_idx]
            if api_id in user_variables:
                chosen = int(user_variables[api_id])
            else:
                if not options:
                    raise InvalidRequestError(
                        f"No valid entity exists at variable '{api_id}' "
                        "for the given anchors and relations"
                    )
                chosen = int(np.random.choice(options))
            entities_skeleton[node_idx] = chosen
            qi_skeleton.vs[node_idx]["e"] = chosen
            progressed = True
        if not progressed:
            break

    # Phase 2: fixed-point fallback for nodes the cache did not cover.
    while True:
        progressed = False
        for node in range(query.query_tree.vcount()):
            if node == query.query_answer or entities_skeleton[node] != -1:
                continue
            incoming = query.query_tree.es[query.query_tree.incident(node, mode="in")]
            if not incoming:
                continue
            first_in = incoming[0]

            if first_in["r"] == "i":
                # Node takes the entity of its parent once that is known.
                parent = first_in.source
                if entities_skeleton[parent] == -1:
                    continue
                entities_skeleton[node] = entities_skeleton[parent]
                qi_skeleton.vs[node]["e"] = entities_skeleton[node]
                progressed = True
                continue

            if "p" not in first_in["r"]:
                continue

            outgoing = qi_skeleton.es[qi_skeleton.incident(node, mode="out")]
            if not outgoing:
                continue
            first_out = outgoing[0]
            child = first_out.target
            if entities_skeleton[child] == -1:
                continue

            rel_id = query_edge_r_to_int(first_out["r"])
            neighbours = adj_s_to_t.get(entities_skeleton[child], {})
            options = list(neighbours.get(rel_id, []))
            if not options:
                api_id = api_id_by_tree_idx.get(node, str(node))
                raise InvalidRequestError(
                    f"No valid entity exists at variable '{api_id}' "
                    "for the given anchors and relations"
                )
            entities_skeleton[node] = int(np.random.choice(options))
            qi_skeleton.vs[node]["e"] = entities_skeleton[node]
            progressed = True

        if not progressed:
            break
|
|
|
|
def _build_response(dataset_id, algorithm, query_structure_id, anchors, relations_map,
                    qs_mapping, predictions, total_ms, step1_ms, step2_ms, speedup, loader,
                    rank_c=1, c_err=0):
    """Assemble the CoinsPredictResponse dict.

    The human-readable ``query_description`` is built from the query-structure
    template: anchors show their cleaned entity name, variables show
    "(variable)", any other node shows "?", and each edge is rendered as
    ``source --[relation]--> target``. Timings are rounded to one decimal and
    the baseline estimate is ``total_ms * speedup``.

    ``qs_mapping`` and ``c_err`` are accepted but not used here.
    """
    inv_nodes, _, inv_relations = loader.dataset.get_inverted_name_maps()

    template = next(qs for qs in QUERY_STRUCTURES if qs["id"] == query_structure_id)

    def _node_label(node):
        # Display label for one template node.
        kind = node["type"]
        if kind == "anchor":
            entity = anchors[node["id"]]
            return clean_entity_name(str(inv_nodes.get(entity, entity)), dataset_id)
        if kind == "variable":
            return "(variable)"
        return "?"

    labels = {}
    for node in template["nodes"]:
        label = _node_label(node)
        labels[node["id"]] = label

    def _edge_text(edge):
        # "source --[relation]--> target" for one template edge.
        relation_id = relations_map[edge["id"]]
        relation = clean_relation_name(
            str(inv_relations.get(relation_id, relation_id)), dataset_id
        )
        return f"{labels[edge['source']]} --[{relation}]--> {labels[edge['target']]}"

    query_description = ", ".join([_edge_text(edge) for edge in template["edges"]])

    timing = {
        "step1_ms": round(step1_ms, 1),
        "step1_label": "Community detection + localized embedding",
        "step2_ms": round(step2_ms, 1),
        "step2_label": "Link prediction",
        "total_ms": round(total_ms, 1),
        "rank_c": rank_c,
        "baseline_estimate_ms": round(total_ms * speedup, 1),
        "baseline_label": "Estimated baseline (without COINs)",
        "speedup": round(speedup, 2),
    }
    return {
        "dataset_id": dataset_id,
        "algorithm": algorithm,
        "query_structure": query_structure_id,
        "query_description": query_description,
        "predictions": predictions,
        "timing": timing,
    }
|
|