"""REST views exposing COINS datasets, sampling helpers, model metadata and prediction."""

from collections.abc import Mapping

from rest_framework.response import Response
from rest_framework.views import APIView

from api.exceptions import InvalidRequestError, NotFoundError
from api.services.constants import (
    COINS_DATASET_META,
    COINS_MODELS,
    QUERY_STRUCTURES,
    QUERY_STRUCTURE_INTERNAL,
    QUERY_TREE_MAPPINGS,
)
from api.services.registry import ModelRegistry
from api.utils import clean_entity_name, clean_relation_name


def _require_loader(dataset_id):
    """Validate dataset_id and ensure its Loader is available.

    Returns the shared ``ModelRegistry`` instance so callers can keep using it.

    Raises:
        NotFoundError: if ``dataset_id`` is not a key of ``COINS_DATASET_META``
            or the registry has no loader for it.
    """
    if dataset_id not in COINS_DATASET_META:
        raise NotFoundError(f"Dataset '{dataset_id}' not found")
    registry = ModelRegistry.get()
    if registry.get_loader(dataset_id) is None:
        raise NotFoundError(f"Dataset '{dataset_id}' data not loaded")
    return registry


def _to_int(raw, what):
    """Convert ``raw`` with ``int()``; report failures as InvalidRequestError.

    ``what`` is a human-readable description of the value used in the error
    message. Without this guard a non-numeric client value would escape the
    view as an unhandled ValueError/TypeError.
    """
    try:
        return int(raw)
    except (TypeError, ValueError, OverflowError):
        raise InvalidRequestError(
            f"{what} must be an integer, got {raw!r}"
        ) from None


def _int_param(source, name, default, lo=None, hi=None):
    """Read integer ``name`` from mapping ``source`` and clamp it to [lo, hi].

    A missing key yields ``default``; either bound may be ``None`` to leave
    that side unclamped.

    Raises:
        InvalidRequestError: if the supplied value is not convertible by ``int()``.
    """
    value = _to_int(source.get(name, default), f"Parameter '{name}'")
    if lo is not None:
        value = max(lo, value)
    if hi is not None:
        value = min(hi, value)
    return value


def _pagination(request):
    """Return ``(q, page, page_size)`` from the query string.

    ``page`` defaults to 1 and is passed through unclamped; ``page_size``
    defaults to 50 and is clamped to the range 1..200.
    """
    params = request.query_params
    q = params.get("q", None)
    page = _int_param(params, "page", 1)
    page_size = _int_param(params, "page_size", 50, lo=1, hi=200)
    return q, page, page_size


def _seed_param(request):
    """Return the raw ``seed`` query parameter, or None when absent or empty."""
    seed_raw = request.query_params.get("seed")
    return seed_raw if seed_raw not in (None, "") else None


def _checked_ids(ids, upper, kind, place):
    """Convert every value of ``ids`` to int and require ``0 <= value < upper``.

    ``kind`` ("Entity"/"Relation") and ``place`` ("node"/"edge") only shape the
    error messages. Returns a new dict with the same keys and int values.

    Raises:
        InvalidRequestError: if a value is not an integer or is out of range.
    """
    checked = {}
    for api_id, raw in ids.items():
        value = _to_int(raw, f"{kind} ID at {place} '{api_id}'")
        if not 0 <= value < upper:
            raise InvalidRequestError(
                f"{kind} ID {raw} at {place} '{api_id}' out of range [0, {upper})"
            )
        checked[api_id] = value
    return checked


class CoinsDatasetsView(APIView):
    """List every dataset in ``COINS_DATASET_META`` with its entity/relation counts."""

    def get(self, request):
        registry = ModelRegistry.get()
        datasets = [
            {
                "id": dataset_id,
                "name": meta["name"],
                "num_entities": registry.get_entity_count(dataset_id),
                "num_relations": registry.get_relation_count(dataset_id),
                "description": meta["description"],
            }
            for dataset_id, meta in COINS_DATASET_META.items()
        ]
        return Response({"datasets": datasets})


class CoinsEntitiesView(APIView):
    """Paginated entity search within one dataset (query params: q, page, page_size)."""

    def get(self, request, dataset_id):
        registry = _require_loader(dataset_id)
        q, page, page_size = _pagination(request)

        page_items, total = registry.search_entities(dataset_id, q, page, page_size)
        return Response({
            "dataset_id": dataset_id,
            "total": total,
            "page": page,
            "page_size": page_size,
            "entities": [
                {"id": eid, "name": name, "label": clean_entity_name(name, dataset_id)}
                for eid, name in page_items
            ],
        })


class CoinsRelationsView(APIView):
    """Paginated relation search within one dataset (query params: q, page, page_size)."""

    def get(self, request, dataset_id):
        registry = _require_loader(dataset_id)
        q, page, page_size = _pagination(request)

        page_items, total = registry.search_relations(dataset_id, q, page, page_size)
        return Response({
            "dataset_id": dataset_id,
            "total": total,
            "page": page,
            "page_size": page_size,
            "relations": [
                {"id": rid, "name": name, "label": clean_relation_name(name, dataset_id)}
                for rid, name in page_items
            ],
        })


class CoinsSampleTriplesView(APIView):
    """Return sampled triples; ``count`` defaults to 10 and is clamped to 1..50."""

    def get(self, request, dataset_id):
        registry = _require_loader(dataset_id)
        count = _int_param(request.query_params, "count", 10, lo=1, hi=50)
        seed = _seed_param(request)
        return Response({
            "dataset_id": dataset_id,
            "triples": registry.sample_triples(dataset_id, count, seed=seed),
        })


class CoinsSampleQueryView(APIView):
    """Return sampled queries of a given structure; ``count`` is clamped to 1..10."""

    def get(self, request, dataset_id):
        registry = _require_loader(dataset_id)

        query_structure = request.query_params.get("query_structure")
        if not query_structure:
            raise InvalidRequestError("Missing required parameter: query_structure")

        valid_qs = {qs["id"] for qs in QUERY_STRUCTURES}
        if query_structure not in valid_qs:
            raise InvalidRequestError(
                f"Unknown query_structure '{query_structure}'. Must be one of: {sorted(valid_qs)}"
            )

        count = _int_param(request.query_params, "count", 1, lo=1, hi=10)
        seed = _seed_param(request)

        queries = registry.sample_query(dataset_id, query_structure, count, seed=seed)
        return Response({
            "dataset_id": dataset_id,
            "query_structure": query_structure,
            "queries": queries,
        })


class CoinsModelsView(APIView):
    """Describe each model in ``COINS_MODELS`` and the datasets it has checkpoints for."""

    def get(self, request):
        registry = ModelRegistry.get()
        models = []
        for model in COINS_MODELS:
            available_datasets = [
                dataset_id
                for dataset_id in COINS_DATASET_META
                if model["algorithm"]
                in registry.coins_checkpoints_available.get(dataset_id, [])
            ]
            models.append({
                "algorithm": model["algorithm"],
                "name": model["name"],
                "description": model["description"],
                "supported_query_structures": model["supported_query_structures"],
                "available_datasets": available_datasets,
            })
        return Response({"models": models})


class CoinsQueryStructuresView(APIView):
    """Return the ``QUERY_STRUCTURES`` constant as-is."""

    def get(self, request):
        return Response({"query_structures": QUERY_STRUCTURES})


class CoinsPredictView(APIView):
    """Validate a prediction request and run it through the model registry."""

    def post(self, request):
        """Run ``registry.coins_predict`` for the posted query.

        Expected body fields: dataset_id, algorithm, query_structure, anchors,
        relations (all required), plus optional variables and top_k (default
        10, clamped to 1..10).

        Raises:
            InvalidRequestError: malformed body, missing/unknown fields,
                non-integer or out-of-range IDs.
            NotFoundError: unknown dataset_id.
            ModelUnavailable: no checkpoint for the dataset/algorithm pair.
            InferenceError: any other exception raised by the prediction call.
        """
        # NOTE(review): kept function-local; presumably avoids an import
        # cycle with api.exceptions -- confirm before hoisting.
        from api.exceptions import InferenceBusy, ModelUnavailable, InferenceError

        data = request.data
        # A non-object body (e.g. a JSON list) has no .get(); reject it up front.
        if not isinstance(data, Mapping):
            raise InvalidRequestError("Request body must be a JSON object")

        # --- Parse required fields ---
        dataset_id = data.get("dataset_id")
        algorithm = data.get("algorithm")
        query_structure = data.get("query_structure")
        anchors = data.get("anchors")
        relations = data.get("relations")
        variables = data.get("variables") or {}
        top_k = _int_param(data, "top_k", 10, lo=1, hi=10)

        # --- Validate required fields ---
        required_present = [
            dataset_id,
            algorithm,
            query_structure,
            anchors is not None,
            relations is not None,
        ]
        if not all(required_present):
            raise InvalidRequestError(
                "Missing required field(s): dataset_id, algorithm, query_structure, anchors, relations"
            )

        if dataset_id not in COINS_DATASET_META:
            raise NotFoundError(f"Dataset '{dataset_id}' not found")

        algo_model = next(
            (m for m in COINS_MODELS if m["algorithm"] == algorithm), None
        )
        if algo_model is None:
            raise InvalidRequestError(f"Unknown algorithm '{algorithm}'")

        if query_structure not in QUERY_STRUCTURE_INTERNAL:
            raise InvalidRequestError(f"Unknown query structure '{query_structure}'")

        # Check algorithm supports query_structure
        if query_structure not in algo_model["supported_query_structures"]:
            raise InvalidRequestError(
                f"Algorithm '{algorithm}' does not support query structure '{query_structure}'"
            )

        # Check checkpoint available
        registry = ModelRegistry.get()
        algo_available = registry.coins_checkpoints_available.get(dataset_id, [])
        if algorithm not in algo_available:
            raise ModelUnavailable(
                f"Model for dataset '{dataset_id}' with algorithm '{algorithm}' is not loaded"
            )

        # The mapping value is not used here; the lookup is kept so a structure
        # without a tree mapping still fails at this point with KeyError.
        _ = QUERY_TREE_MAPPINGS[query_structure]

        # Find anchor and variable node IDs from template
        qs_template = next(qs for qs in QUERY_STRUCTURES if qs["id"] == query_structure)
        anchor_node_ids = {n["id"] for n in qs_template["nodes"] if n["type"] == "anchor"}
        variable_node_ids = {n["id"] for n in qs_template["nodes"] if n["type"] == "variable"}
        edge_ids = {e["id"] for e in qs_template["edges"]}

        # The key comparisons below need mappings; a list or scalar here would
        # otherwise fail with AttributeError on .keys().
        for field_name, value in (
            ("anchors", anchors),
            ("relations", relations),
            ("variables", variables),
        ):
            if not isinstance(value, Mapping):
                raise InvalidRequestError(
                    f"Field '{field_name}' must be an object mapping IDs to integers"
                )

        if set(anchors.keys()) != anchor_node_ids:
            raise InvalidRequestError(
                f"anchors keys {set(anchors.keys())} must exactly match anchor nodes {anchor_node_ids}"
            )
        if set(relations.keys()) != edge_ids:
            raise InvalidRequestError(
                f"relations keys {set(relations.keys())} must exactly match edge IDs {edge_ids}"
            )
        if not set(variables.keys()).issubset(variable_node_ids):
            raise InvalidRequestError(
                f"variables keys {set(variables.keys())} must be a subset of variable nodes {variable_node_ids}"
            )

        # Validate IDs are integers in range and convert them in one pass
        # (entities first, then relations).
        num_entities = registry.get_entity_count(dataset_id)
        num_rels = registry.get_relation_count(dataset_id)
        anchors = _checked_ids(anchors, num_entities, "Entity", "node")
        variables = _checked_ids(variables, num_entities, "Entity", "node")
        relations = _checked_ids(relations, num_rels, "Relation", "edge")

        try:
            result = registry.coins_predict(
                dataset_id,
                algorithm,
                query_structure,
                anchors,
                variables,
                relations,
                top_k,
            )
        except (InferenceBusy, InvalidRequestError, ModelUnavailable):
            # These already carry the right API semantics; pass them through.
            raise
        except Exception as exc:
            raise InferenceError(f"Inference failed: {exc}") from exc

        return Response(result)