website / src /backend /api /views /coins.py
Andrej Janchevski
feat(coins): add sample-query endpoint and resolve Freebase MIDs to names
db54566
from rest_framework.response import Response
from rest_framework.views import APIView
from api.exceptions import InvalidRequestError, NotFoundError
from api.services.constants import (
COINS_DATASET_META, COINS_MODELS, QUERY_STRUCTURES,
QUERY_STRUCTURE_INTERNAL, QUERY_TREE_MAPPINGS,
)
from api.services.registry import ModelRegistry
from api.utils import clean_entity_name, clean_relation_name
def _require_loader(dataset_id):
    """Return the shared ModelRegistry once ``dataset_id`` is confirmed usable.

    Raises:
        NotFoundError: if ``dataset_id`` is not a key of COINS_DATASET_META,
            or if the registry's ``get_loader`` returns None for it.
    """
    is_known = dataset_id in COINS_DATASET_META
    if not is_known:
        raise NotFoundError(f"Dataset '{dataset_id}' not found")

    # The registry is only fetched once the dataset ID is known to be valid.
    model_registry = ModelRegistry.get()
    loader = model_registry.get_loader(dataset_id)
    if loader is None:
        raise NotFoundError(f"Dataset '{dataset_id}' data not loaded")

    return model_registry
class CoinsDatasetsView(APIView):
    """List every dataset declared in COINS_DATASET_META with its size counts."""

    def get(self, request):
        """Return ``{"datasets": [...]}`` with one summary dict per dataset."""
        model_registry = ModelRegistry.get()

        def summarize(ds_id, info):
            # Counts come from the registry; name/description from the metadata.
            return {
                "id": ds_id,
                "name": info["name"],
                "num_entities": model_registry.get_entity_count(ds_id),
                "num_relations": model_registry.get_relation_count(ds_id),
                "description": info["description"],
            }

        summaries = [
            summarize(ds_id, info) for ds_id, info in COINS_DATASET_META.items()
        ]
        return Response({"datasets": summaries})
class CoinsEntitiesView(APIView):
    """Paginated, optionally filtered listing of a dataset's entities."""

    def get(self, request, dataset_id):
        """Return one page of entities for ``dataset_id``.

        Query parameters:
            q: optional search string, passed through to the registry.
            page: page number (default 1, values below 1 are raised to 1).
            page_size: items per page (default 50, clamped to [1, 200]).

        Raises:
            NotFoundError: via ``_require_loader`` when the dataset is unknown
                or its data is not loaded.
            InvalidRequestError: when ``page`` or ``page_size`` is not an
                integer.
        """
        registry = _require_loader(dataset_id)
        params = request.query_params
        q = params.get("q", None)

        # Query-string values are untrusted text: reject non-integers with a
        # client error instead of letting ValueError escape as a server error.
        try:
            page = int(params.get("page", 1))
            page_size = int(params.get("page_size", 50))
        except (TypeError, ValueError) as exc:
            raise InvalidRequestError(
                "Query parameters 'page' and 'page_size' must be integers"
            ) from exc

        # The default page is 1, so anything lower is treated as the first page.
        page = max(1, page)
        page_size = max(1, min(200, page_size))

        page_items, total = registry.search_entities(dataset_id, q, page, page_size)
        return Response({
            "dataset_id": dataset_id,
            "total": total,
            "page": page,
            "page_size": page_size,
            "entities": [
                {"id": eid, "name": name, "label": clean_entity_name(name, dataset_id)}
                for eid, name in page_items
            ],
        })
class CoinsRelationsView(APIView):
    """Paginated, optionally filtered listing of a dataset's relations."""

    def get(self, request, dataset_id):
        """Return one page of relations for ``dataset_id``.

        Query parameters:
            q: optional search string, passed through to the registry.
            page: page number (default 1, values below 1 are raised to 1).
            page_size: items per page (default 50, clamped to [1, 200]).

        Raises:
            NotFoundError: via ``_require_loader`` when the dataset is unknown
                or its data is not loaded.
            InvalidRequestError: when ``page`` or ``page_size`` is not an
                integer.
        """
        registry = _require_loader(dataset_id)
        params = request.query_params
        q = params.get("q", None)

        # Query-string values are untrusted text: reject non-integers with a
        # client error instead of letting ValueError escape as a server error.
        try:
            page = int(params.get("page", 1))
            page_size = int(params.get("page_size", 50))
        except (TypeError, ValueError) as exc:
            raise InvalidRequestError(
                "Query parameters 'page' and 'page_size' must be integers"
            ) from exc

        # The default page is 1, so anything lower is treated as the first page.
        page = max(1, page)
        page_size = max(1, min(200, page_size))

        page_items, total = registry.search_relations(dataset_id, q, page, page_size)
        return Response({
            "dataset_id": dataset_id,
            "total": total,
            "page": page,
            "page_size": page_size,
            "relations": [
                {"id": rid, "name": name, "label": clean_relation_name(name, dataset_id)}
                for rid, name in page_items
            ],
        })
class CoinsSampleTriplesView(APIView):
    """Return a sample of triples from a dataset."""

    def get(self, request, dataset_id):
        """Sample triples from ``dataset_id`` through the registry.

        Query parameters:
            count: number of triples (default 10, clamped to [1, 50]).
            seed: optional seed; a missing or empty value is passed as None,
                otherwise the raw string is forwarded unchanged.

        Raises:
            NotFoundError: via ``_require_loader`` when the dataset is unknown
                or its data is not loaded.
            InvalidRequestError: when ``count`` is not an integer.
        """
        registry = _require_loader(dataset_id)

        # Reject a non-integer count with a client error rather than letting
        # ValueError escape as a server error.
        try:
            count = int(request.query_params.get("count", 10))
        except (TypeError, ValueError) as exc:
            raise InvalidRequestError(
                "Query parameter 'count' must be an integer"
            ) from exc
        count = max(1, min(50, count))

        seed_raw = request.query_params.get("seed")
        seed = seed_raw if seed_raw not in (None, "") else None

        return Response({
            "dataset_id": dataset_id,
            "triples": registry.sample_triples(dataset_id, count, seed=seed),
        })
class CoinsSampleQueryView(APIView):
    """Return sample queries of a given structure from a dataset."""

    def get(self, request, dataset_id):
        """Sample queries for ``dataset_id`` through the registry.

        Query parameters:
            query_structure: required; must equal the ``id`` of an entry in
                QUERY_STRUCTURES.
            count: number of queries (default 1, clamped to [1, 10]).
            seed: optional seed; a missing or empty value is passed as None,
                otherwise the raw string is forwarded unchanged.

        Raises:
            NotFoundError: via ``_require_loader`` when the dataset is unknown
                or its data is not loaded.
            InvalidRequestError: when ``query_structure`` is missing or
                unknown, or when ``count`` is not an integer.
        """
        registry = _require_loader(dataset_id)

        query_structure = request.query_params.get("query_structure")
        if not query_structure:
            raise InvalidRequestError("Missing required parameter: query_structure")

        valid_qs = {qs["id"] for qs in QUERY_STRUCTURES}
        if query_structure not in valid_qs:
            raise InvalidRequestError(
                f"Unknown query_structure '{query_structure}'. Must be one of: {sorted(valid_qs)}"
            )

        # Reject a non-integer count with a client error rather than letting
        # ValueError escape as a server error.
        try:
            count = int(request.query_params.get("count", 1))
        except (TypeError, ValueError) as exc:
            raise InvalidRequestError(
                "Query parameter 'count' must be an integer"
            ) from exc
        count = max(1, min(10, count))

        seed_raw = request.query_params.get("seed")
        seed = seed_raw if seed_raw not in (None, "") else None

        queries = registry.sample_query(dataset_id, query_structure, count, seed=seed)
        return Response({
            "dataset_id": dataset_id,
            "query_structure": query_structure,
            "queries": queries,
        })
class CoinsModelsView(APIView):
    """List the COINS models together with the datasets each has a checkpoint for."""

    def get(self, request):
        """Return ``{"models": [...]}`` with one descriptor per COINS_MODELS entry."""
        model_registry = ModelRegistry.get()
        checkpoints = model_registry.coins_checkpoints_available

        descriptors = []
        for spec in COINS_MODELS:
            # A dataset counts as available when this algorithm is listed
            # among the checkpoints registered for it.
            usable_on = [
                ds_id
                for ds_id in COINS_DATASET_META
                if spec["algorithm"] in checkpoints.get(ds_id, [])
            ]
            descriptors.append({
                "algorithm": spec["algorithm"],
                "name": spec["name"],
                "description": spec["description"],
                "supported_query_structures": spec["supported_query_structures"],
                "available_datasets": usable_on,
            })

        return Response({"models": descriptors})
class CoinsQueryStructuresView(APIView):
    """Expose the QUERY_STRUCTURES constant to API clients."""

    def get(self, request):
        """Return QUERY_STRUCTURES unchanged under the ``query_structures`` key."""
        payload = {"query_structures": QUERY_STRUCTURES}
        return Response(payload)
class CoinsPredictView(APIView):
    """Validate a prediction request and run it through the model registry."""

    @staticmethod
    def _coerce_ids(raw, noun, place, upper):
        """Convert a ``{key: id}`` mapping to ints, validating each ID.

        Args:
            raw: mapping of node/edge key to a client-supplied ID.
            noun: "Entity" or "Relation", used in error messages.
            place: "node" or "edge", used in error messages.
            upper: exclusive upper bound for valid IDs.

        Returns:
            A new dict with the same keys and ``int`` values.

        Raises:
            InvalidRequestError: if an ID is not convertible to int or lies
                outside ``[0, upper)``.
        """
        coerced = {}
        for key, value in raw.items():
            try:
                as_int = int(value)
            except (TypeError, ValueError, OverflowError) as exc:
                raise InvalidRequestError(
                    f"{noun} ID {value!r} at {place} '{key}' is not an integer"
                ) from exc
            if not (0 <= as_int < upper):
                raise InvalidRequestError(
                    f"{noun} ID {value} at {place} '{key}' out of range [0, {upper})"
                )
            coerced[key] = as_int
        return coerced

    def post(self, request):
        """Run a COINS prediction described by the request body.

        Body fields:
            dataset_id, algorithm, query_structure: required identifiers.
            anchors: required object keyed by the template's anchor node IDs.
            relations: required object keyed by the template's edge IDs.
            variables: optional object keyed by a subset of the template's
                variable node IDs.
            top_k: optional integer (default 10, clamped to [1, 10]).

        Raises:
            InvalidRequestError: for a malformed body, missing fields, unknown
                algorithm or query structure, mismatched keys, or bad IDs.
            NotFoundError: when ``dataset_id`` is not in COINS_DATASET_META.
            ModelUnavailable: when no checkpoint is registered for the
                dataset/algorithm pair.
            InferenceBusy: propagated unchanged from the registry.
            InferenceError: wraps any other exception raised during inference.
        """
        from api.exceptions import InferenceBusy, ModelUnavailable, InferenceError

        data = request.data
        # A JSON array or scalar body has no fields to read.
        if not isinstance(data, dict):
            raise InvalidRequestError("Request body must be a JSON object")

        # --- Parse required fields ---
        dataset_id = data.get("dataset_id")
        algorithm = data.get("algorithm")
        query_structure = data.get("query_structure")
        anchors = data.get("anchors")
        relations = data.get("relations")
        variables = data.get("variables") or {}

        # A non-integer top_k is a client error, not a server error.
        try:
            top_k = int(data.get("top_k", 10))
        except (TypeError, ValueError, OverflowError) as exc:
            raise InvalidRequestError("top_k must be an integer") from exc
        top_k = max(1, min(10, top_k))

        # --- Validate required fields ---
        if not all([dataset_id, algorithm, query_structure, anchors is not None, relations is not None]):
            raise InvalidRequestError(
                "Missing required field(s): dataset_id, algorithm, query_structure, anchors, relations"
            )

        # The key and ID checks below need mappings; reject lists/scalars early.
        for field_name, value in (
            ("anchors", anchors),
            ("relations", relations),
            ("variables", variables),
        ):
            if not isinstance(value, dict):
                raise InvalidRequestError(f"'{field_name}' must be an object")

        if dataset_id not in COINS_DATASET_META:
            raise NotFoundError(f"Dataset '{dataset_id}' not found")

        # Look the algorithm up once; a miss means it is unknown.
        algo_model = next(
            (m for m in COINS_MODELS if m["algorithm"] == algorithm), None
        )
        if algo_model is None:
            raise InvalidRequestError(f"Unknown algorithm '{algorithm}'")

        if query_structure not in QUERY_STRUCTURE_INTERNAL:
            raise InvalidRequestError(f"Unknown query structure '{query_structure}'")

        # Check algorithm supports query_structure
        if query_structure not in algo_model["supported_query_structures"]:
            raise InvalidRequestError(
                f"Algorithm '{algorithm}' does not support query structure '{query_structure}'"
            )

        # Check checkpoint available
        registry = ModelRegistry.get()
        algo_available = registry.coins_checkpoints_available.get(dataset_id, [])
        if algorithm not in algo_available:
            raise ModelUnavailable(
                f"Model for dataset '{dataset_id}' with algorithm '{algorithm}' is not loaded"
            )

        # Find anchor and variable node IDs from template. A missing template
        # is reported as an unknown structure instead of leaking StopIteration.
        qs_template = next(
            (qs for qs in QUERY_STRUCTURES if qs["id"] == query_structure), None
        )
        if qs_template is None:
            raise InvalidRequestError(f"Unknown query structure '{query_structure}'")

        anchor_node_ids = {n["id"] for n in qs_template["nodes"] if n["type"] == "anchor"}
        variable_node_ids = {n["id"] for n in qs_template["nodes"] if n["type"] == "variable"}
        edge_ids = {e["id"] for e in qs_template["edges"]}

        if set(anchors) != anchor_node_ids:
            raise InvalidRequestError(
                f"anchors keys {set(anchors)} must exactly match anchor nodes {anchor_node_ids}"
            )
        if set(relations) != edge_ids:
            raise InvalidRequestError(
                f"relations keys {set(relations)} must exactly match edge IDs {edge_ids}"
            )
        if not set(variables).issubset(variable_node_ids):
            raise InvalidRequestError(
                f"variables keys {set(variables)} must be a subset of variable nodes {variable_node_ids}"
            )

        # Convert IDs to int and range-check them in a single pass.
        num_entities = registry.get_entity_count(dataset_id)
        num_rels = registry.get_relation_count(dataset_id)
        anchors = self._coerce_ids(anchors, "Entity", "node", num_entities)
        variables = self._coerce_ids(variables, "Entity", "node", num_entities)
        relations = self._coerce_ids(relations, "Relation", "edge", num_rels)

        try:
            result = registry.coins_predict(
                dataset_id, algorithm, query_structure,
                anchors, variables, relations, top_k,
            )
        except (InferenceBusy, InvalidRequestError, ModelUnavailable):
            # These already carry the right API status; pass them through.
            raise
        except Exception as exc:
            raise InferenceError(f"Inference failed: {exc}") from exc

        return Response(result)