| from rest_framework.response import Response |
| from rest_framework.views import APIView |
|
|
| from api.exceptions import InvalidRequestError, NotFoundError |
| from api.services.constants import ( |
| COINS_DATASET_META, COINS_MODELS, QUERY_STRUCTURES, |
| QUERY_STRUCTURE_INTERNAL, QUERY_TREE_MAPPINGS, |
| ) |
| from api.services.registry import ModelRegistry |
| from api.utils import clean_entity_name, clean_relation_name |
|
|
|
|
def _require_loader(dataset_id):
    """Return the model registry once ``dataset_id`` has been checked.

    Raises:
        NotFoundError: if ``dataset_id`` is not a key of
            ``COINS_DATASET_META``, or if the registry reports no loader
            for it.
    """
    is_known = dataset_id in COINS_DATASET_META
    if not is_known:
        raise NotFoundError(f"Dataset '{dataset_id}' not found")

    model_registry = ModelRegistry.get()
    loader = model_registry.get_loader(dataset_id)
    if loader is None:
        raise NotFoundError(f"Dataset '{dataset_id}' data not loaded")

    return model_registry
|
|
|
|
class CoinsDatasetsView(APIView):
    """List every dataset declared in ``COINS_DATASET_META``."""

    def get(self, request):
        """Return one summary entry per dataset under the ``datasets`` key.

        Each entry combines the static ``name`` / ``description`` metadata
        with the entity and relation counts reported by the model registry.
        """
        reg = ModelRegistry.get()
        summaries = [
            {
                "id": ds_id,
                "name": info["name"],
                "num_entities": reg.get_entity_count(ds_id),
                "num_relations": reg.get_relation_count(ds_id),
                "description": info["description"],
            }
            for ds_id, info in COINS_DATASET_META.items()
        ]
        return Response({"datasets": summaries})
|
|
|
|
class CoinsEntitiesView(APIView):
    """Search the entities of one dataset, one page at a time."""

    def get(self, request, dataset_id):
        """Return a page of entities for ``dataset_id``.

        Query parameters:
            q: optional search string, passed through to the registry.
            page: page number (default 1); values below 1 are raised to 1.
            page_size: items per page (default 50), clamped to [1, 200].

        Raises:
            NotFoundError: via ``_require_loader`` when the dataset is
                unknown or has no loader.
            InvalidRequestError: when ``page`` or ``page_size`` is not an
                integer.
        """
        registry = _require_loader(dataset_id)
        params = request.query_params
        q = params.get("q", None)

        # Query parameters are client-supplied text: a non-numeric value must
        # be reported as a bad request instead of escaping as a ValueError.
        try:
            page = int(params.get("page", 1))
            page_size = int(params.get("page_size", 50))
        except (TypeError, ValueError) as exc:
            raise InvalidRequestError(
                "Query parameters 'page' and 'page_size' must be integers"
            ) from exc

        # The default of 1 indicates 1-based pages, so never pass a smaller
        # page number on to the registry.
        page = max(1, page)
        page_size = max(1, min(200, page_size))

        page_items, total = registry.search_entities(dataset_id, q, page, page_size)

        return Response({
            "dataset_id": dataset_id,
            "total": total,
            "page": page,
            "page_size": page_size,
            "entities": [
                {"id": eid, "name": name, "label": clean_entity_name(name, dataset_id)}
                for eid, name in page_items
            ],
        })
|
|
|
|
class CoinsRelationsView(APIView):
    """Search the relations of one dataset, one page at a time."""

    def get(self, request, dataset_id):
        """Return a page of relations for ``dataset_id``.

        Query parameters:
            q: optional search string, passed through to the registry.
            page: page number (default 1); values below 1 are raised to 1.
            page_size: items per page (default 50), clamped to [1, 200].

        Raises:
            NotFoundError: via ``_require_loader`` when the dataset is
                unknown or has no loader.
            InvalidRequestError: when ``page`` or ``page_size`` is not an
                integer.
        """
        registry = _require_loader(dataset_id)
        params = request.query_params
        q = params.get("q", None)

        # Query parameters are client-supplied text: a non-numeric value must
        # be reported as a bad request instead of escaping as a ValueError.
        try:
            page = int(params.get("page", 1))
            page_size = int(params.get("page_size", 50))
        except (TypeError, ValueError) as exc:
            raise InvalidRequestError(
                "Query parameters 'page' and 'page_size' must be integers"
            ) from exc

        # The default of 1 indicates 1-based pages, so never pass a smaller
        # page number on to the registry.
        page = max(1, page)
        page_size = max(1, min(200, page_size))

        page_items, total = registry.search_relations(dataset_id, q, page, page_size)

        return Response({
            "dataset_id": dataset_id,
            "total": total,
            "page": page,
            "page_size": page_size,
            "relations": [
                {"id": rid, "name": name, "label": clean_relation_name(name, dataset_id)}
                for rid, name in page_items
            ],
        })
|
|
|
|
class CoinsSampleTriplesView(APIView):
    """Return sample triples drawn from one dataset."""

    def get(self, request, dataset_id):
        """Sample up to ``count`` triples from ``dataset_id``.

        Query parameters:
            count: number of triples (default 10), clamped to [1, 50].
            seed: optional seed, forwarded to the registry as received;
                a missing or empty value is forwarded as ``None``.

        Raises:
            NotFoundError: via ``_require_loader`` when the dataset is
                unknown or has no loader.
            InvalidRequestError: when ``count`` is not an integer.
        """
        registry = _require_loader(dataset_id)

        # A non-numeric ``count`` is a client error, not a server fault.
        try:
            count = int(request.query_params.get("count", 10))
        except (TypeError, ValueError) as exc:
            raise InvalidRequestError(
                "Query parameter 'count' must be an integer"
            ) from exc
        count = max(1, min(50, count))

        seed_raw = request.query_params.get("seed")
        seed = seed_raw if seed_raw not in (None, "") else None

        return Response({
            "dataset_id": dataset_id,
            "triples": registry.sample_triples(dataset_id, count, seed=seed),
        })
|
|
|
|
class CoinsSampleQueryView(APIView):
    """Return sample queries of a given structure for one dataset."""

    def get(self, request, dataset_id):
        """Sample up to ``count`` queries of ``query_structure``.

        Query parameters:
            query_structure: required; must equal the ``id`` of an entry in
                ``QUERY_STRUCTURES``.
            count: number of queries (default 1), clamped to [1, 10].
            seed: optional seed, forwarded to the registry as received;
                a missing or empty value is forwarded as ``None``.

        Raises:
            NotFoundError: via ``_require_loader`` when the dataset is
                unknown or has no loader.
            InvalidRequestError: when ``query_structure`` is missing or
                unknown, or when ``count`` is not an integer.
        """
        registry = _require_loader(dataset_id)

        query_structure = request.query_params.get("query_structure")
        if not query_structure:
            raise InvalidRequestError("Missing required parameter: query_structure")
        valid_qs = {qs["id"] for qs in QUERY_STRUCTURES}
        if query_structure not in valid_qs:
            raise InvalidRequestError(
                f"Unknown query_structure '{query_structure}'. Must be one of: {sorted(valid_qs)}"
            )

        # A non-numeric ``count`` is a client error, not a server fault.
        try:
            count = int(request.query_params.get("count", 1))
        except (TypeError, ValueError) as exc:
            raise InvalidRequestError(
                "Query parameter 'count' must be an integer"
            ) from exc
        count = max(1, min(10, count))

        seed_raw = request.query_params.get("seed")
        seed = seed_raw if seed_raw not in (None, "") else None

        queries = registry.sample_query(dataset_id, query_structure, count, seed=seed)
        return Response({
            "dataset_id": dataset_id,
            "query_structure": query_structure,
            "queries": queries,
        })
|
|
|
|
class CoinsModelsView(APIView):
    """List the entries of ``COINS_MODELS`` with the datasets each can serve."""

    def get(self, request):
        """Return every model description under the ``models`` key.

        ``available_datasets`` holds the dataset IDs whose entry in
        ``registry.coins_checkpoints_available`` contains the model's
        algorithm.
        """
        reg = ModelRegistry.get()
        described = []
        for spec in COINS_MODELS:
            usable_on = [
                ds_id
                for ds_id in COINS_DATASET_META
                if spec["algorithm"] in reg.coins_checkpoints_available.get(ds_id, [])
            ]
            described.append({
                "algorithm": spec["algorithm"],
                "name": spec["name"],
                "description": spec["description"],
                "supported_query_structures": spec["supported_query_structures"],
                "available_datasets": usable_on,
            })
        return Response({"models": described})
|
|
|
|
class CoinsQueryStructuresView(APIView):
    """Expose the ``QUERY_STRUCTURES`` constant to API clients."""

    def get(self, request):
        """Return ``QUERY_STRUCTURES`` unchanged under ``query_structures``."""
        payload = {"query_structures": QUERY_STRUCTURES}
        return Response(payload)
|
|
|
|
class CoinsPredictView(APIView):
    """Validate a prediction request and run it through the model registry."""

    @staticmethod
    def _coerce_ids(raw_ids, upper_bound, kind, slot):
        """Convert each value of ``raw_ids`` to ``int`` and range-check it.

        ``kind`` ("Entity" / "Relation") and ``slot`` ("node" / "edge") are
        only used to build the error message.

        Returns a new dict with the same keys and integer values.

        Raises:
            InvalidRequestError: if a value cannot be converted to ``int``
                or falls outside ``[0, upper_bound)``.
        """
        coerced = {}
        for api_id, raw in raw_ids.items():
            try:
                value = int(raw)
            except (TypeError, ValueError, OverflowError) as exc:
                raise InvalidRequestError(
                    f"{kind} ID {raw!r} at {slot} '{api_id}' is not an integer"
                ) from exc
            if not (0 <= value < upper_bound):
                raise InvalidRequestError(
                    f"{kind} ID {raw} at {slot} '{api_id}' out of range [0, {upper_bound})"
                )
            coerced[api_id] = value
        return coerced

    def post(self, request):
        """Run a prediction described by the request body.

        Body fields:
            dataset_id, algorithm, query_structure: required identifiers.
            anchors: required object keyed by anchor-node ID, holding
                entity IDs.
            relations: required object keyed by edge ID, holding relation
                IDs.
            variables: optional object keyed by variable-node ID, holding
                entity IDs.
            top_k: optional integer (default 10), clamped to [1, 10].

        Raises:
            InvalidRequestError: for missing, malformed, unknown or
                out-of-range input.
            NotFoundError: if ``dataset_id`` is not in ``COINS_DATASET_META``.
            ModelUnavailable: if ``algorithm`` is not listed for the dataset
                in ``registry.coins_checkpoints_available``.
            InferenceBusy: propagated unchanged from the registry call.
            InferenceError: wraps any other exception raised by the
                registry call.
        """
        from api.exceptions import InferenceBusy, ModelUnavailable, InferenceError

        data = request.data
        # A JSON body that is a list or scalar has no ``.get``; reject it
        # up front instead of failing with AttributeError below.
        if not callable(getattr(data, "get", None)):
            raise InvalidRequestError("Request body must be a JSON object")

        dataset_id = data.get("dataset_id")
        algorithm = data.get("algorithm")
        query_structure = data.get("query_structure")
        anchors = data.get("anchors")
        relations = data.get("relations")
        variables = data.get("variables") or {}

        try:
            top_k = int(data.get("top_k", 10))
        except (TypeError, ValueError, OverflowError) as exc:
            raise InvalidRequestError("top_k must be an integer") from exc
        top_k = max(1, min(10, top_k))

        if not all([dataset_id, algorithm, query_structure, anchors is not None, relations is not None]):
            raise InvalidRequestError(
                "Missing required field(s): dataset_id, algorithm, query_structure, anchors, relations"
            )

        # Containers are unhashable and would raise TypeError in the
        # membership tests against the constant tables below.
        if any(isinstance(v, (dict, list)) for v in (dataset_id, algorithm, query_structure)):
            raise InvalidRequestError(
                "dataset_id, algorithm and query_structure must be scalar values"
            )
        # ``.keys()`` / ``.items()`` are called on these three below.
        if not all(isinstance(v, dict) for v in (anchors, relations, variables)):
            raise InvalidRequestError(
                "anchors, relations and variables must be JSON objects"
            )

        if dataset_id not in COINS_DATASET_META:
            raise NotFoundError(f"Dataset '{dataset_id}' not found")

        # One pass both validates the algorithm and fetches its description.
        algo_model = next((m for m in COINS_MODELS if m["algorithm"] == algorithm), None)
        if algo_model is None:
            raise InvalidRequestError(f"Unknown algorithm '{algorithm}'")

        if query_structure not in QUERY_STRUCTURE_INTERNAL:
            raise InvalidRequestError(f"Unknown query structure '{query_structure}'")

        if query_structure not in algo_model["supported_query_structures"]:
            raise InvalidRequestError(
                f"Algorithm '{algorithm}' does not support query structure '{query_structure}'"
            )

        registry = ModelRegistry.get()
        algo_available = registry.coins_checkpoints_available.get(dataset_id, [])
        if algorithm not in algo_available:
            raise ModelUnavailable(
                f"Model for dataset '{dataset_id}' with algorithm '{algorithm}' is not loaded"
            )

        # NOTE(review): the value is unused in this method; the lookup is
        # kept because it raises KeyError when the structure has no tree
        # mapping -- confirm whether that fail-fast is intended.
        qs_mapping = QUERY_TREE_MAPPINGS[query_structure]

        # The structure was validated against QUERY_STRUCTURE_INTERNAL, not
        # QUERY_STRUCTURES, so guard against the template being absent.
        qs_template = next(
            (qs for qs in QUERY_STRUCTURES if qs["id"] == query_structure), None
        )
        if qs_template is None:
            raise InvalidRequestError(f"Unknown query structure '{query_structure}'")

        anchor_node_ids = {n["id"] for n in qs_template["nodes"] if n["type"] == "anchor"}
        variable_node_ids = {n["id"] for n in qs_template["nodes"] if n["type"] == "variable"}
        edge_ids = {e["id"] for e in qs_template["edges"]}

        if set(anchors.keys()) != anchor_node_ids:
            raise InvalidRequestError(
                f"anchors keys {set(anchors.keys())} must exactly match anchor nodes {anchor_node_ids}"
            )
        if set(relations.keys()) != edge_ids:
            raise InvalidRequestError(
                f"relations keys {set(relations.keys())} must exactly match edge IDs {edge_ids}"
            )
        if not set(variables.keys()).issubset(variable_node_ids):
            raise InvalidRequestError(
                f"variables keys {set(variables.keys())} must be a subset of variable nodes {variable_node_ids}"
            )

        # Convert and range-check in a single pass: entity IDs first
        # (anchors, then variables), relation IDs last.
        num_entities = registry.get_entity_count(dataset_id)
        num_rels = registry.get_relation_count(dataset_id)
        anchors = self._coerce_ids(anchors, num_entities, "Entity", "node")
        variables = self._coerce_ids(variables, num_entities, "Entity", "node")
        relations = self._coerce_ids(relations, num_rels, "Relation", "edge")

        try:
            result = registry.coins_predict(
                dataset_id, algorithm, query_structure,
                anchors, variables, relations, top_k,
            )
        except (InferenceBusy, InvalidRequestError, ModelUnavailable):
            # These already carry their own API meaning; pass them through.
            raise
        except Exception as exc:
            raise InferenceError(f"Inference failed: {exc}") from exc

        return Response(result)
|
|