File size: 15,283 Bytes
3ad32ba
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2d03843
 
 
 
 
3ad32ba
 
 
 
 
 
 
 
 
 
 
 
 
 
c1b3cc7
3ad32ba
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c1b3cc7
3ad32ba
 
 
 
 
 
c1b3cc7
3ad32ba
 
 
 
 
c1b3cc7
3ad32ba
 
 
 
c1b3cc7
 
 
3ad32ba
 
 
 
 
 
 
 
 
 
 
2d03843
 
 
 
 
3ad32ba
 
 
 
 
 
 
 
2d03843
 
 
 
3ad32ba
 
 
2d03843
3ad32ba
 
 
 
 
 
2d03843
 
3ad32ba
 
2d03843
3ad32ba
 
 
 
2d03843
3ad32ba
 
 
 
2d03843
 
 
 
 
 
3ad32ba
 
2d03843
 
 
 
3ad32ba
 
 
 
2d03843
3ad32ba
 
 
2d03843
3ad32ba
 
 
2d03843
3ad32ba
 
 
 
2d03843
 
 
 
 
3ad32ba
2d03843
3ad32ba
 
 
 
 
 
 
2d03843
3ad32ba
 
 
 
 
 
 
2d03843
3ad32ba
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
"""COINs link prediction / query answering β€” inference logic.

This module is intentionally free of Django settings and registry state.
It receives a fully-prepared Experiment object and runs the inference pipeline.
"""
import time

import numpy as np
import torch as pt
from torch.nn.functional import one_hot

from api.exceptions import InvalidRequestError
from api.services.constants import QUERY_STRUCTURE_INTERNAL, QUERY_TREE_MAPPINGS, QUERY_STRUCTURES
from api.utils import clean_entity_name, clean_relation_name

# Research-code imports — available after settings.py adds the repos to sys.path.
from graph_completion.graphs.queries import Query, get_all_answers, get_node_cut_cache, query_edge_r_to_int
from graph_completion.graphs.preprocess import QueryData

# Step-2 mini-batch cap: the link-ranker runs a full forward pass per candidate,
# and the hit community can hold thousands of nodes on Freebase. Splitting keeps
# CPU memory bounded without changing results.
# Passed to QueryData.batch_split() in coins_predict_inner as the chunk size.
SCORING_MINI_BATCH_SIZE = 512


def coins_predict_inner(experiment, dataset_id, algorithm, query_structure_id,
                        anchors, variables, relations_map, top_k):
    """Run a single COINs query and return the CoinsPredictResponse dict.

    The pipeline has two timed steps:
      1. score every community as a container for the answer;
      2. walk the communities in descending step-1 score until one contains a
         KG-valid answer, then score every non-anchor entity of that community.

    Args:
        experiment: fully-prepared Experiment; its ``embedder``, ``link_ranker``,
            ``loader`` and ``full_adj_s_to_t`` attributes are read here.
        dataset_id: dataset identifier, used for name cleaning and echoed back.
        algorithm: algorithm identifier, echoed back in the response.
        query_structure_id: key into QUERY_TREE_MAPPINGS / QUERY_STRUCTURE_INTERNAL.
        anchors: {api_node_id: entity_id} for all anchor nodes.
        variables: {api_node_id: entity_id} for variable nodes (may be empty).
        relations_map: {api_edge_id: relation_id} for all edges.
        top_k: number of top predictions to return (documented API maximum is
            10; the cap is not enforced here).

    Returns:
        The dict assembled by ``_build_response``.

    Raises:
        InvalidRequestError: if no knowledge-graph entity satisfies the query,
            if the hit community has no non-anchor candidate, or if
            ``_resolve_variables`` cannot resolve a variable.
    """
    embedder = experiment.embedder
    link_ranker = experiment.link_ranker
    loader = experiment.loader
    device = next(embedder.parameters()).device

    qs_mapping = QUERY_TREE_MAPPINGS[query_structure_id]
    num_nodes = loader.num_nodes
    num_relations = loader.num_relations
    num_node_types = loader.num_node_types
    num_communities = loader.num_communities
    community_membership = embedder.community_membership   # PT tensor [num_nodes], long
    node_types_tensor = embedder.node_types                # PT tensor [num_nodes], long

    # Full-KG adjacency (train + val + test) — graph_indexes cover train only,
    # which would miss answers in val/test splits.
    adj_s_to_t = experiment.full_adj_s_to_t

    # ---- Build Query tree ----
    query = Query(QUERY_STRUCTURE_INTERNAL[query_structure_id])
    query.build_query_tree()
    query.get_node_cut()

    num_tree_nodes = query.query_tree.vcount()
    num_tree_edges = query.query_tree.ecount()

    # ---- Build query_instance_mapped skeleton ----
    qi_skeleton = query.query_tree.copy()
    entities_skeleton = [-1] * num_tree_nodes  # -1 = unresolved; avoids confusion with entity 0
    for api_id, entity_id in anchors.items():
        tree_idx = qs_mapping["nodes"][api_id]
        entities_skeleton[tree_idx] = int(entity_id)
    qi_skeleton.vs["e"] = entities_skeleton
    for api_id, rel_id in relations_map.items():
        edge_idx = qs_mapping["edges"][api_id]
        qi_skeleton.es[edge_idx]["r"] = f"p{rel_id}"
    # intersection edges keep "i" from build_query_tree

    # ---- Resolve variable entities (mutates qi_skeleton / entities_skeleton) ----
    _resolve_variables(qi_skeleton, query, entities_skeleton, qs_mapping,
                       variables, adj_s_to_t)

    # ---- Step 1: Community scoring ----
    t1_start = time.perf_counter()

    all_communities = pt.arange(num_communities, dtype=pt.long, device=device)
    e_c = []
    for i in range(num_tree_nodes):
        entity = entities_skeleton[i]
        if i == query.query_answer or entity == -1:
            # The answer, and any position still unresolved after
            # _resolve_variables, ranges over every community. This matches
            # step 2, where the same -1 positions take each candidate.
            # Indexing community_membership with -1 would silently read the
            # community of the last node instead.
            e_c.append(all_communities)
        else:
            com = int(community_membership[entity].item())
            e_c.append(pt.full([num_communities], com, dtype=pt.long, device=device))
    edge_attr_c = _edge_attr_one_hots(qi_skeleton, num_tree_edges, num_communities,
                                      num_relations, device)

    with pt.no_grad():
        community_query = QueryData(query, e_c=e_c, edge_attr_c=edge_attr_c)
        c_q_emb, c_a_emb = embedder.embed_communities(community_query)
        community_scores = link_ranker(c_q_emb, c_a_emb, for_communities=True)

    # Sort communities by descending score for the search loop
    community_order = community_scores.argsort(descending=True)  # [K]
    step1_ms = (time.perf_counter() - t1_start) * 1000.0

    # ---- Step 2: Score every entity in the hit community ----
    # Mirrors rank_samples (experiments.py:782-786): the hit community is the
    # first one in step-1 order that contains *any* KG-valid answer; within it
    # we score all entities (minus anchors) so the link ranker actually has to
    # discriminate, instead of being handed only the known answers.
    t2_start = time.perf_counter()

    valid_answers = set(get_all_answers(qi_skeleton, query, adj_s_to_t))
    if not valid_answers:
        raise InvalidRequestError(
            "No entities in the knowledge graph satisfy this query"
        )

    anchor_entity_ids = {
        entities_skeleton[i] for i in query.query_anchors if entities_skeleton[i] != -1
    }

    rank_c = 0           # 1-based rank of the hit community (0 = no hit)
    c_err = 0            # sum of sizes of communities with better step-1 score than the hit
    community_size = 0   # size of the hit community
    candidates = []

    for rank_0indexed, cid in enumerate(community_order.tolist()):
        # One bulk tolist() instead of a Python-level .item() per entity.
        c_entities = (community_membership == cid).nonzero(as_tuple=True)[0].tolist()
        c_size = len(c_entities)

        if not valid_answers.isdisjoint(c_entities):
            rank_c = rank_0indexed + 1
            community_size = c_size
            candidates = [e for e in c_entities if e not in anchor_entity_ids]
            break

        c_err += c_size

    if not candidates:
        raise InvalidRequestError(
            "No entities in the knowledge graph satisfy this query"
        )

    n_candidates = len(candidates)
    candidates_tensor = pt.tensor(candidates, dtype=pt.long, device=device)

    # Build per-tree-node entity columns. Anchor / resolved-variable positions
    # repeat the same id across the batch; the answer and phantom-i positions
    # take each candidate.
    e_batch, x_batch, c_batch = [], [], []
    for i in range(num_tree_nodes):
        if entities_skeleton[i] == -1:
            entities_i = candidates_tensor
        else:
            entities_i = pt.full([n_candidates], entities_skeleton[i], dtype=pt.long, device=device)
        e_batch.append(entities_i)
        x_batch.append(one_hot(node_types_tensor[entities_i], num_node_types).float())
        c_batch.append(community_membership[entities_i])
    edge_attr_batch = _edge_attr_one_hots(qi_skeleton, num_tree_edges, n_candidates,
                                          num_relations, device)

    with pt.no_grad():
        batched_query = QueryData(query, e=e_batch, x=x_batch, c=c_batch, edge_attr=edge_attr_batch)
        score_chunks = []
        # Chunked forward passes keep memory bounded; scores are concatenated
        # back in candidate order.
        for chunk in batched_query.batch_split(SCORING_MINI_BATCH_SIZE):
            q_emb, a_emb = embedder(chunk)
            score_chunks.append(link_ranker(q_emb, a_emb).view(-1))
        scores = pt.cat(score_chunks)

    k = min(top_k, n_candidates)
    top_scores, top_indices = scores.topk(k)
    step2_ms = (time.perf_counter() - t2_start) * 1000.0

    # rank = c_err + intra_community_rank  (exact rank_samples formula)
    inv_nodes, _, _ = loader.dataset.get_inverted_name_maps()
    predictions = []
    for intra_community_rank, (idx, score) in enumerate(zip(top_indices.tolist(), top_scores.tolist()), 1):
        entity_id = candidates[idx]
        raw_name = str(inv_nodes.get(entity_id, entity_id))
        predictions.append({
            "intra_community_rank": intra_community_rank,
            "rank": c_err + intra_community_rank,
            "entity_id": entity_id,
            "entity_name": clean_entity_name(raw_name, dataset_id),
            "score": round(float(score), 4),
            "is_valid_answer": entity_id in valid_answers,
        })

    total_ms = step1_ms + step2_ms
    # Speedup accounts for the full search cost: K community scores + skipped communities + hit community
    speedup = num_nodes / (num_communities + c_err + community_size)
    return _build_response(dataset_id, algorithm, query_structure_id, anchors,
                           relations_map, qs_mapping, predictions, total_ms,
                           step1_ms, step2_ms, speedup, loader,
                           rank_c=rank_c, c_err=c_err)


def _edge_attr_one_hots(qi_skeleton, num_tree_edges, batch_size, num_relations, device):
    """Return one float one-hot tensor [batch_size, num_relations + 1] per tree edge.

    Projection edges (label "p<rel_id>") select slot ``rel_id``; every other
    edge (the intersection "i" edges) selects the extra "no relation" slot at
    index ``num_relations``.
    """
    edge_attrs = []
    for j in range(num_tree_edges):
        r_label = qi_skeleton.es[j]["r"]
        r_id = int(r_label[1:]) if "p" in r_label else num_relations
        edge_attrs.append(
            one_hot(pt.full([batch_size], r_id, dtype=pt.long, device=device), num_relations + 1).float()
        )
    return edge_attrs


def _resolve_variables(qi_skeleton, query, entities_skeleton, qs_mapping,
                       user_variables, adj_s_to_t):
    """Fill in unspecified variable entities of the query instance.

    Updates ``qi_skeleton`` (vertex attribute "e") and ``entities_skeleton``
    in-place; returns nothing. Resolution happens in two phases:

    1. Up to 3 passes over ``get_node_cut_cache`` (3p has 2 variables). A
       user-supplied entity wins; otherwise one cached candidate is drawn at
       random with ``np.random.choice``.
    2. A fixed-point sweep over nodes the cache does not cover: phantom "i"
       nodes copy their resolved parent, and non-cut "p" variables take the
       user-supplied entity if there is one, else a random neighbour of their
       resolved child under the edge's relation.

    Because of the random draws, results are not deterministic unless the
    caller seeds NumPy's global RNG.

    Args:
        qi_skeleton: query-instance graph whose edges carry concrete "r" labels.
        query: Query object providing ``query_tree`` and ``query_answer``.
        entities_skeleton: per-tree-node entity ids, ``-1`` meaning unresolved.
        qs_mapping: mapping with a "nodes" dict {api_node_id: tree_node_index}.
        user_variables: {api_node_id: entity_id} chosen by the user (may be empty).
        adj_s_to_t: nested mapping read as ``adj_s_to_t[entity][rel_id]``.

    Raises:
        InvalidRequestError: if a variable has no candidate entity and the
            user did not supply one.
    """
    node_api_ids = {v: k for k, v in qs_mapping["nodes"].items()}

    def _assign(tree_node_idx, entity):
        """Record a resolved entity in both the list and the graph attribute."""
        entities_skeleton[tree_node_idx] = entity
        qi_skeleton.vs[tree_node_idx]["e"] = entity

    def _no_candidate_error(api_id):
        return InvalidRequestError(
            f"No valid entity exists at variable '{api_id}' "
            "for the given anchors and relations"
        )

    def _run_pass():
        """Resolve every still-open node present in the node-cut cache."""
        cache = get_node_cut_cache(qi_skeleton, query, adj_s_to_t)
        resolved_any = False
        for tree_node_idx, candidates in cache.items():
            if tree_node_idx not in node_api_ids:
                continue
            api_id = node_api_ids[tree_node_idx]
            if entities_skeleton[tree_node_idx] != -1:
                continue  # anchor or already-resolved variable
            if api_id in user_variables:
                entity = int(user_variables[api_id])
            elif not candidates:
                raise _no_candidate_error(api_id)
            else:
                entity = int(np.random.choice(candidates))
            _assign(tree_node_idx, entity)
            resolved_any = True
        return resolved_any

    for _ in range(3):
        if not _run_pass():
            break

    # Post-processing: propagate resolved entities to nodes that get_node_cut_cache
    # can't reach (they lie between the cut and the anchors or answer in the tree).
    # Two cases:
    #  - Phantom "i" nodes (e.g. ip nodes 1,3): inherit their resolved ancestor's entity
    #  - Non-cut "p" variable nodes (e.g. 3p v1): user value, else adj_s_to_t[child_entity][rel]
    changed = True
    while changed:
        changed = False
        for i in range(query.query_tree.vcount()):
            if i == query.query_answer or entities_skeleton[i] != -1:
                continue
            in_edges = query.query_tree.es[query.query_tree.incident(i, mode="in")]
            if not in_edges:
                continue
            in_edge = in_edges[0]
            if in_edge["r"] == "i":
                # Phantom intersection node: copy the resolved parent (intersection var)
                parent = in_edge.source
                if entities_skeleton[parent] != -1:
                    _assign(i, entities_skeleton[parent])
                    changed = True
            elif "p" in in_edge["r"]:
                api_id = node_api_ids.get(i)
                if api_id is not None and api_id in user_variables:
                    # A user-pinned variable off the node cut must be honoured
                    # exactly like one on the cut, not replaced by a random draw.
                    _assign(i, int(user_variables[api_id]))
                    changed = True
                    continue
                # Projection variable between cut and anchors: resolve from child.
                # Use qi_skeleton.es to get the concrete "p{rel_id}" label (not the
                # generic "p" from query.query_tree that would cause int("") errors).
                out_edges = qi_skeleton.es[qi_skeleton.incident(i, mode="out")]
                if not out_edges:
                    continue
                child = out_edges[0].target
                if entities_skeleton[child] == -1:
                    continue  # child not resolved yet; retry on a later sweep
                rel_id = query_edge_r_to_int(out_edges[0]["r"])
                candidates = list(adj_s_to_t.get(entities_skeleton[child], {}).get(rel_id, []))
                if not candidates:
                    raise _no_candidate_error(node_api_ids.get(i, str(i)))
                _assign(i, int(np.random.choice(candidates)))
                changed = True

def _build_response(dataset_id, algorithm, query_structure_id, anchors, relations_map,
                    qs_mapping, predictions, total_ms, step1_ms, step2_ms, speedup, loader,
                    rank_c=1, c_err=0):
    """Assemble the CoinsPredictResponse dict.

    Looks up the query-structure template by id, renders a human-readable
    description of the query (one "source --[relation]--> target" fragment per
    template edge, comma-separated) and packages it with the predictions and
    timing figures. ``qs_mapping`` and ``c_err`` are accepted but not read here.
    """
    node_lookup, _, relation_lookup = loader.dataset.get_inverted_name_maps()

    template = next(entry for entry in QUERY_STRUCTURES if entry["id"] == query_structure_id)

    # Display label for every template node: cleaned entity name for anchors,
    # fixed placeholders for variables and for anything else (the answer).
    labels = {}
    for vertex in template["nodes"]:
        vertex_id = vertex["id"]
        kind = vertex["type"]
        if kind == "anchor":
            anchor_name = str(node_lookup.get(anchors[vertex_id], anchors[vertex_id]))
            labels[vertex_id] = clean_entity_name(anchor_name, dataset_id)
        elif kind == "variable":
            labels[vertex_id] = "(variable)"
        else:
            labels[vertex_id] = "?"

    def _describe(edge):
        """Render a single template edge as 'source --[relation]--> target'."""
        relation_id = relations_map[edge["id"]]
        relation_label = clean_relation_name(
            str(relation_lookup.get(relation_id, relation_id)), dataset_id
        )
        return f"{labels[edge['source']]} --[{relation_label}]--> {labels[edge['target']]}"

    # Comma-separated fragments give a clean multi-pattern display.
    query_description = ", ".join([_describe(edge) for edge in template["edges"]])

    timing = {
        "step1_ms": round(step1_ms, 1),
        "step1_label": "Community detection + localized embedding",
        "step2_ms": round(step2_ms, 1),
        "step2_label": "Link prediction",
        "total_ms": round(total_ms, 1),
        "rank_c": rank_c,
        "baseline_estimate_ms": round(total_ms * speedup, 1),
        "baseline_label": "Estimated baseline (without COINs)",
        "speedup": round(speedup, 2),
    }

    return {
        "dataset_id": dataset_id,
        "algorithm": algorithm,
        "query_structure": query_structure_id,
        "query_description": query_description,
        "predictions": predictions,
        "timing": timing,
    }