File size: 9,797 Bytes
4f1e196
 
 
3ad32ba
 
 
 
 
4f1e196
db54566
4f1e196
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
db54566
 
 
 
4f1e196
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
db54566
 
 
 
4f1e196
 
 
 
 
 
 
 
 
ee3e8fe
 
 
4f1e196
 
ee3e8fe
4f1e196
 
 
db54566
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4f1e196
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3ad32ba
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
from rest_framework.response import Response
from rest_framework.views import APIView

from api.exceptions import InvalidRequestError, NotFoundError
from api.services.constants import (
    COINS_DATASET_META, COINS_MODELS, QUERY_STRUCTURES,
    QUERY_STRUCTURE_INTERNAL, QUERY_TREE_MAPPINGS,
)
from api.services.registry import ModelRegistry
from api.utils import clean_entity_name, clean_relation_name


def _require_loader(dataset_id):
    """Return the model registry once ``dataset_id`` is known and loaded.

    Raises:
        NotFoundError: if ``dataset_id`` is not a key of
            ``COINS_DATASET_META``, or if the registry's ``get_loader``
            returns ``None`` for it.
    """
    is_known = dataset_id in COINS_DATASET_META
    if not is_known:
        raise NotFoundError(f"Dataset '{dataset_id}' not found")

    model_registry = ModelRegistry.get()
    loader = model_registry.get_loader(dataset_id)
    if loader is None:
        raise NotFoundError(f"Dataset '{dataset_id}' data not loaded")

    return model_registry


class CoinsDatasetsView(APIView):
    """List every dataset declared in ``COINS_DATASET_META``."""

    def get(self, request):
        """Return ``{"datasets": [...]}`` with one entry per configured dataset.

        Each entry carries the dataset id, the name and description from the
        metadata table, and the entity/relation counts reported by the
        model registry.
        """
        model_registry = ModelRegistry.get()
        payload = [
            {
                "id": ds_id,
                "name": info["name"],
                "num_entities": model_registry.get_entity_count(ds_id),
                "num_relations": model_registry.get_relation_count(ds_id),
                "description": info["description"],
            }
            for ds_id, info in COINS_DATASET_META.items()
        ]
        return Response({"datasets": payload})


class CoinsEntitiesView(APIView):
    """Paginated, optionally filtered listing of a dataset's entities."""

    def get(self, request, dataset_id):
        """Return one page of entities for ``dataset_id``.

        Query parameters:
            q: optional search string, forwarded unchanged to the registry.
            page: page number (default 1; values below 1 are treated as 1).
            page_size: items per page (default 50, clamped to 1..200).

        Raises:
            NotFoundError: via ``_require_loader`` when the dataset is
                unknown or its data is not loaded.
            InvalidRequestError: when ``page`` or ``page_size`` is not an
                integer.
        """
        registry = _require_loader(dataset_id)
        params = request.query_params
        q = params.get("q", None)

        # A malformed number is a client error; without this guard the bare
        # int() call escapes as an unhandled ValueError.
        try:
            page = int(params.get("page", 1))
            page_size = int(params.get("page_size", 50))
        except (TypeError, ValueError):
            raise InvalidRequestError(
                "Parameters 'page' and 'page_size' must be integers"
            ) from None

        # Zero or negative pages would otherwise reach the registry unchecked.
        page = max(1, page)
        page_size = max(1, min(200, page_size))

        page_items, total = registry.search_entities(dataset_id, q, page, page_size)

        return Response({
            "dataset_id": dataset_id,
            "total": total,
            "page": page,
            "page_size": page_size,
            "entities": [
                {"id": eid, "name": name, "label": clean_entity_name(name, dataset_id)}
                for eid, name in page_items
            ],
        })


class CoinsRelationsView(APIView):
    """Paginated, optionally filtered listing of a dataset's relations."""

    def get(self, request, dataset_id):
        """Return one page of relations for ``dataset_id``.

        Query parameters:
            q: optional search string, forwarded unchanged to the registry.
            page: page number (default 1; values below 1 are treated as 1).
            page_size: items per page (default 50, clamped to 1..200).

        Raises:
            NotFoundError: via ``_require_loader`` when the dataset is
                unknown or its data is not loaded.
            InvalidRequestError: when ``page`` or ``page_size`` is not an
                integer.
        """
        registry = _require_loader(dataset_id)
        params = request.query_params
        q = params.get("q", None)

        # A malformed number is a client error; without this guard the bare
        # int() call escapes as an unhandled ValueError.
        try:
            page = int(params.get("page", 1))
            page_size = int(params.get("page_size", 50))
        except (TypeError, ValueError):
            raise InvalidRequestError(
                "Parameters 'page' and 'page_size' must be integers"
            ) from None

        # Zero or negative pages would otherwise reach the registry unchecked.
        page = max(1, page)
        page_size = max(1, min(200, page_size))

        page_items, total = registry.search_relations(dataset_id, q, page, page_size)

        return Response({
            "dataset_id": dataset_id,
            "total": total,
            "page": page,
            "page_size": page_size,
            "relations": [
                {"id": rid, "name": name, "label": clean_relation_name(name, dataset_id)}
                for rid, name in page_items
            ],
        })


class CoinsSampleTriplesView(APIView):
    """Return a sample of triples from a loaded dataset."""

    def get(self, request, dataset_id):
        """Sample triples from ``dataset_id``.

        Query parameters:
            count: number of triples (default 10, clamped to 1..50).
            seed: optional seed; a missing or empty value is passed to the
                registry as ``None``, anything else as the raw string.

        Raises:
            NotFoundError: via ``_require_loader`` when the dataset is
                unknown or its data is not loaded.
            InvalidRequestError: when ``count`` is not an integer.
        """
        registry = _require_loader(dataset_id)
        params = request.query_params

        # A malformed number is a client error; without this guard the bare
        # int() call escapes as an unhandled ValueError.
        try:
            count = int(params.get("count", 10))
        except (TypeError, ValueError):
            raise InvalidRequestError("Parameter 'count' must be an integer") from None
        count = max(1, min(50, count))

        # Query-string values are str or None, so "" and None both become None.
        seed = params.get("seed") or None

        return Response({
            "dataset_id": dataset_id,
            "triples": registry.sample_triples(dataset_id, count, seed=seed),
        })


class CoinsSampleQueryView(APIView):
    """Return sample queries of a requested structure from a loaded dataset."""

    def get(self, request, dataset_id):
        """Sample queries of one structure from ``dataset_id``.

        Query parameters:
            query_structure: required; must equal the ``id`` of an entry in
                ``QUERY_STRUCTURES``.
            count: number of queries (default 1, clamped to 1..10).
            seed: optional seed; a missing or empty value is passed to the
                registry as ``None``, anything else as the raw string.

        Raises:
            NotFoundError: via ``_require_loader`` when the dataset is
                unknown or its data is not loaded.
            InvalidRequestError: when ``query_structure`` is missing or
                unknown, or when ``count`` is not an integer.
        """
        registry = _require_loader(dataset_id)
        params = request.query_params

        query_structure = params.get("query_structure")
        if not query_structure:
            raise InvalidRequestError("Missing required parameter: query_structure")
        valid_qs = {qs["id"] for qs in QUERY_STRUCTURES}
        if query_structure not in valid_qs:
            raise InvalidRequestError(
                f"Unknown query_structure '{query_structure}'. Must be one of: {sorted(valid_qs)}"
            )

        # A malformed number is a client error; without this guard the bare
        # int() call escapes as an unhandled ValueError.
        try:
            count = int(params.get("count", 1))
        except (TypeError, ValueError):
            raise InvalidRequestError("Parameter 'count' must be an integer") from None
        count = max(1, min(10, count))

        # Query-string values are str or None, so "" and None both become None.
        seed = params.get("seed") or None

        queries = registry.sample_query(dataset_id, query_structure, count, seed=seed)
        return Response({
            "dataset_id": dataset_id,
            "query_structure": query_structure,
            "queries": queries,
        })


class CoinsModelsView(APIView):
    """List the entries of ``COINS_MODELS`` with their available datasets."""

    def get(self, request):
        """Return ``{"models": [...]}`` with one entry per configured model.

        ``available_datasets`` holds every dataset id from
        ``COINS_DATASET_META`` whose entry in the registry's
        ``coins_checkpoints_available`` mapping lists the model's algorithm.
        """
        model_registry = ModelRegistry.get()

        def datasets_offering(algorithm):
            # Dataset ids keep the iteration order of COINS_DATASET_META.
            return [
                ds_id
                for ds_id in COINS_DATASET_META
                if algorithm in model_registry.coins_checkpoints_available.get(ds_id, [])
            ]

        listing = []
        for spec in COINS_MODELS:
            available = datasets_offering(spec["algorithm"])
            listing.append({
                "algorithm": spec["algorithm"],
                "name": spec["name"],
                "description": spec["description"],
                "supported_query_structures": spec["supported_query_structures"],
                "available_datasets": available,
            })
        return Response({"models": listing})


class CoinsQueryStructuresView(APIView):
    """Expose the ``QUERY_STRUCTURES`` constant over the API."""

    def get(self, request):
        """Return ``QUERY_STRUCTURES`` under the ``query_structures`` key."""
        body = {"query_structures": QUERY_STRUCTURES}
        return Response(body)


class CoinsPredictView(APIView):
    """Validate a prediction request and run it through the model registry."""

    @staticmethod
    def _to_int(raw, message):
        """Return ``int(raw)``, raising ``InvalidRequestError(message)`` on failure."""
        try:
            return int(raw)
        except (TypeError, ValueError, OverflowError):
            # OverflowError covers float infinities, which int() rejects.
            raise InvalidRequestError(message) from None

    @classmethod
    def _checked_ids(cls, mapping, upper, kind, slot):
        """Convert each value of ``mapping`` to an int within ``[0, upper)``.

        ``kind`` ("Entity"/"Relation") and ``slot`` ("node"/"edge") only
        shape the error message. Returns a new dict with the same keys.
        """
        converted = {}
        for api_id, raw in mapping.items():
            value = cls._to_int(
                raw,
                f"{kind} ID {raw!r} at {slot} '{api_id}' must be an integer",
            )
            if not (0 <= value < upper):
                raise InvalidRequestError(
                    f"{kind} ID {raw} at {slot} '{api_id}' out of range [0, {upper})"
                )
            converted[api_id] = value
        return converted

    def post(self, request):
        """Run a prediction described by the request body.

        Body fields:
            dataset_id, algorithm, query_structure: required identifiers.
            anchors: required object; its keys must equal the anchor node
                ids of the query-structure template.
            relations: required object; its keys must equal the template's
                edge ids.
            variables: optional object; its keys must be a subset of the
                template's variable node ids.
            top_k: optional integer (default 10, clamped to 1..10).

        Raises:
            InvalidRequestError: for missing or malformed fields, unknown
                algorithm or query structure, mismatched keys, or IDs that
                are non-integer or out of range.
            NotFoundError: when ``dataset_id`` is not in
                ``COINS_DATASET_META``.
            ModelUnavailable: when the registry lists no checkpoint for the
                dataset/algorithm pair.
            InferenceError: wrapping any other exception raised by
                ``registry.coins_predict``.
        """
        from api.exceptions import InferenceBusy, ModelUnavailable, InferenceError

        data = request.data
        if not hasattr(data, "get"):
            # e.g. a JSON array or scalar body; .get() below would crash.
            raise InvalidRequestError("Request body must be a JSON object")

        # --- Parse required fields ---
        dataset_id = data.get("dataset_id")
        algorithm = data.get("algorithm")
        query_structure = data.get("query_structure")
        anchors = data.get("anchors")
        relations = data.get("relations")
        variables = data.get("variables") or {}
        top_k = self._to_int(data.get("top_k", 10), "Field 'top_k' must be an integer")
        top_k = max(1, min(10, top_k))

        # --- Validate required fields ---
        if not all([dataset_id, algorithm, query_structure, anchors is not None, relations is not None]):
            raise InvalidRequestError(
                "Missing required field(s): dataset_id, algorithm, query_structure, anchors, relations"
            )

        if dataset_id not in COINS_DATASET_META:
            raise NotFoundError(f"Dataset '{dataset_id}' not found")

        # One lookup serves both the "known algorithm" check and the
        # supported-structure check below.
        algo_model = next((m for m in COINS_MODELS if m["algorithm"] == algorithm), None)
        if algo_model is None:
            raise InvalidRequestError(f"Unknown algorithm '{algorithm}'")

        if query_structure not in QUERY_STRUCTURE_INTERNAL:
            raise InvalidRequestError(f"Unknown query structure '{query_structure}'")

        # Check algorithm supports query_structure
        if query_structure not in algo_model["supported_query_structures"]:
            raise InvalidRequestError(
                f"Algorithm '{algorithm}' does not support query structure '{query_structure}'"
            )

        # Check checkpoint available
        registry = ModelRegistry.get()
        algo_available = registry.coins_checkpoints_available.get(dataset_id, [])
        if algorithm not in algo_available:
            raise ModelUnavailable(
                f"Model for dataset '{dataset_id}' with algorithm '{algorithm}' is not loaded"
            )

        # NOTE(review): the mapping value is unused in this method; the lookup
        # is kept so a structure without a tree mapping still fails before
        # inference. Confirm whether it can be dropped.
        qs_mapping = QUERY_TREE_MAPPINGS[query_structure]  # noqa: F841

        # Find anchor and variable node IDs from template. A structure missing
        # from QUERY_STRUCTURES is reported as unknown instead of letting a
        # bare next() raise StopIteration.
        qs_template = next(
            (qs for qs in QUERY_STRUCTURES if qs["id"] == query_structure), None
        )
        if qs_template is None:
            raise InvalidRequestError(f"Unknown query structure '{query_structure}'")
        anchor_node_ids = {n["id"] for n in qs_template["nodes"] if n["type"] == "anchor"}
        variable_node_ids = {n["id"] for n in qs_template["nodes"] if n["type"] == "variable"}
        edge_ids = {e["id"] for e in qs_template["edges"]}

        # The key checks below call .keys(); reject non-object payloads first.
        for field, value in (
            ("anchors", anchors),
            ("relations", relations),
            ("variables", variables),
        ):
            if not isinstance(value, dict):
                raise InvalidRequestError(f"Field '{field}' must be an object")

        if set(anchors.keys()) != anchor_node_ids:
            raise InvalidRequestError(
                f"anchors keys {set(anchors.keys())} must exactly match anchor nodes {anchor_node_ids}"
            )
        if set(relations.keys()) != edge_ids:
            raise InvalidRequestError(
                f"relations keys {set(relations.keys())} must exactly match edge IDs {edge_ids}"
            )
        if not set(variables.keys()).issubset(variable_node_ids):
            raise InvalidRequestError(
                f"variables keys {set(variables.keys())} must be a subset of variable nodes {variable_node_ids}"
            )

        # Convert to int and range-check in one pass: anchors, then
        # variables, then relations.
        num_entities = registry.get_entity_count(dataset_id)
        num_rels = registry.get_relation_count(dataset_id)
        anchors = self._checked_ids(anchors, num_entities, "Entity", "node")
        variables = self._checked_ids(variables, num_entities, "Entity", "node")
        relations = self._checked_ids(relations, num_rels, "Relation", "edge")

        try:
            result = registry.coins_predict(
                dataset_id, algorithm, query_structure,
                anchors, variables, relations, top_k,
            )
        except (InferenceBusy, InvalidRequestError, ModelUnavailable):
            # These are propagated unchanged rather than wrapped.
            raise
        except Exception as exc:
            raise InferenceError(f"Inference failed: {exc}") from exc

        return Response(result)