website / src /backend /api /views /graph_generation.py
Andrej Janchevski
feat(multiproxan): stream graph-generation progress and refine comm20 render
713aa17
import json
from django.http import StreamingHttpResponse
from rest_framework.renderers import JSONRenderer
from rest_framework.response import Response
from rest_framework.views import APIView
from api.exceptions import InvalidRequestError, ModelUnavailable
from api.renderers import EventStreamRenderer
from api.services.constants import GRAPHGEN_DATASETS, GRAPHGEN_SAMPLING_MODES
from api.services.registry import ModelRegistry
class GraphGenDatasetsView(APIView):
    """List the graph-generation datasets and their available model types."""

    # Metadata keys copied verbatim from each GRAPHGEN_DATASETS entry, in the
    # order they appear in the response.
    _META_FIELDS = (
        "name",
        "type",
        "description",
        "node_types",
        "edge_types",
        "max_nodes",
    )

    def get(self, request):
        """Return ``{"datasets": [...]}`` with one entry per known dataset.

        Each entry carries the dataset id, the metadata fields listed in
        ``_META_FIELDS`` and ``available_model_types``, which is whatever the
        registry reports for that dataset (an empty list when it reports
        nothing).
        """
        model_registry = ModelRegistry.get()

        def describe(dataset_id, meta):
            model_types = model_registry.graphgen_checkpoints_available.get(
                dataset_id, [])
            entry = {"id": dataset_id}
            entry.update((field, meta[field]) for field in self._META_FIELDS)
            entry["available_model_types"] = model_types
            return entry

        catalogue = [
            describe(dataset_id, meta)
            for dataset_id, meta in GRAPHGEN_DATASETS.items()
        ]
        return Response({"datasets": catalogue})
class GraphGenSamplingModesView(APIView):
    """Expose the sampling modes declared in ``GRAPHGEN_SAMPLING_MODES``."""

    def get(self, request):
        """Return ``{"sampling_modes": GRAPHGEN_SAMPLING_MODES}``."""
        payload = {"sampling_modes": GRAPHGEN_SAMPLING_MODES}
        return Response(payload)
class GraphGenGenerateView(APIView):
    """Validate a graph-generation request and stream its progress as SSE.

    Expected request fields:

    * ``dataset_id`` -- a key of ``GRAPHGEN_DATASETS``.
    * ``model_type`` -- ``"discrete"`` or ``"continuous"``; must be listed by
      the registry as available for the dataset.
    * ``sampling_mode`` -- ``"standard"`` or ``"multiprox"``.
    * ``num_nodes`` -- optional integer in ``[1, max_nodes]`` of the dataset.
    * ``diffusion_steps`` / ``chain_frames`` -- standard mode only; default to
      500 / 10 and are clamped to ``[50, 1000]`` / ``[10, 30]``.
    * ``multiprox_params`` -- multiprox mode only; a dict with optional
      ``m`` (default 100, in ``[2, 100]``), ``n`` (default 10, ``>= 1``),
      ``t`` (default 0.5), ``t_prime`` (default 0.1) with
      ``0 < t_prime <= t <= 1``, and ``gibbs_chain_freq``
      (default ``max(1, m // 10)``, in ``[1, m]``).

    Every malformed value is reported as ``InvalidRequestError`` rather than
    surfacing as an unhandled ``TypeError``/``ValueError``.
    """

    renderer_classes = [EventStreamRenderer, JSONRenderer]

    @staticmethod
    def _as_int(value, field):
        """Coerce *value* to ``int`` or raise ``InvalidRequestError`` for *field*."""
        # bool is an int subclass; reject it so JSON ``true`` is not read as 1.
        if isinstance(value, bool):
            raise InvalidRequestError(f"{field} must be an integer")
        try:
            return int(value)
        except (TypeError, ValueError, OverflowError) as exc:
            raise InvalidRequestError(f"{field} must be an integer") from exc

    @staticmethod
    def _as_float(value, field):
        """Coerce *value* to ``float`` or raise ``InvalidRequestError`` for *field*."""
        if isinstance(value, bool):
            raise InvalidRequestError(f"{field} must be a number")
        try:
            return float(value)
        except (TypeError, ValueError, OverflowError) as exc:
            raise InvalidRequestError(f"{field} must be a number") from exc

    def post(self, request):
        """Validate the request body and return a streaming SSE response.

        Raises:
            InvalidRequestError: if any field is missing, of the wrong type,
                or out of range.
            ModelUnavailable: if the registry lists no checkpoint of the
                requested ``model_type`` for the dataset.
        """
        data = request.data
        # A JSON array or scalar body has no ``.get``; reject it up front.
        if not isinstance(data, dict):
            raise InvalidRequestError("Request body must be a JSON object")
        registry = ModelRegistry.get()

        dataset_id = data.get("dataset_id")
        try:
            known_dataset = dataset_id in GRAPHGEN_DATASETS
        except TypeError:
            # Unhashable ids (e.g. a JSON list) cannot be looked up at all.
            known_dataset = False
        if not known_dataset:
            raise InvalidRequestError(
                f"Unknown dataset_id '{dataset_id}'. Valid: {list(GRAPHGEN_DATASETS)}")

        model_type = data.get("model_type")
        if model_type not in ("discrete", "continuous"):
            raise InvalidRequestError("model_type must be 'discrete' or 'continuous'")

        sampling_mode = data.get("sampling_mode")
        if sampling_mode not in ("standard", "multiprox"):
            raise InvalidRequestError("sampling_mode must be 'standard' or 'multiprox'")

        available = registry.graphgen_checkpoints_available.get(dataset_id, [])
        if model_type not in available:
            raise ModelUnavailable(
                f"No {model_type} checkpoint available for dataset '{dataset_id}'")

        num_nodes = data.get("num_nodes")
        if num_nodes is not None:
            max_nodes = GRAPHGEN_DATASETS[dataset_id]["max_nodes"]
            is_plain_int = isinstance(num_nodes, int) and not isinstance(num_nodes, bool)
            if not is_plain_int or not (1 <= num_nodes <= max_nodes):
                raise InvalidRequestError(
                    f"num_nodes must be an integer in [1, {max_nodes}] for dataset '{dataset_id}'")

        if sampling_mode == "standard":
            # Out-of-range values are clamped rather than rejected.
            diffusion_steps = min(max(
                self._as_int(data.get("diffusion_steps", 500), "diffusion_steps"),
                50), 1000)
            chain_frames = min(max(
                self._as_int(data.get("chain_frames", 10), "chain_frames"),
                10), 30)
            gen = registry.graphgen_generate_stream(
                dataset_id, model_type, sampling_mode, num_nodes,
                diffusion_steps, chain_frames, None)
        else:
            mp = data.get("multiprox_params")
            if not mp or not isinstance(mp, dict):
                raise InvalidRequestError("multiprox_params is required for multiprox sampling_mode")

            m = self._as_int(mp.get("m", 100), "multiprox_params.m")
            if not (2 <= m <= 100):
                raise InvalidRequestError("multiprox_params.m must be in [2, 100]")

            n = self._as_int(mp.get("n", 10), "multiprox_params.n")
            if n < 1:
                raise InvalidRequestError("multiprox_params.n must be >= 1")

            t = self._as_float(mp.get("t", 0.5), "multiprox_params.t")
            t_prime = self._as_float(mp.get("t_prime", 0.1), "multiprox_params.t_prime")
            # NaN fails every comparison, so it is rejected here as well.
            if not (0 < t_prime <= t <= 1):
                raise InvalidRequestError(
                    "multiprox_params must satisfy 0 < t_prime <= t <= 1")

            gibbs_chain_freq = self._as_int(
                mp.get("gibbs_chain_freq", max(1, m // 10)),
                "multiprox_params.gibbs_chain_freq")
            if not (1 <= gibbs_chain_freq <= m):
                raise InvalidRequestError(
                    f"multiprox_params.gibbs_chain_freq must be in [1, {m}]")

            multiprox_params = {
                "n": n, "m": m, "t": t, "t_prime": t_prime,
                "gibbs_chain_freq": gibbs_chain_freq,
            }
            gen = registry.graphgen_generate_stream(
                dataset_id, model_type, sampling_mode, num_nodes,
                None, None, multiprox_params)

        return _streaming_sse_response(gen)
class GraphGenContinueView(APIView):
    """Resume graph generation from a client-supplied state string."""

    renderer_classes = [EventStreamRenderer, JSONRenderer]

    def post(self, request):
        """Validate ``state`` and stream the continued generation as SSE.

        Raises:
            InvalidRequestError: if the body is not a JSON object or ``state``
                is missing, empty, or not a string.
        """
        data = request.data
        # A JSON array or scalar body has no ``.get``; treat it as a missing
        # ``state`` so the client gets the validation error, not a 500.
        state_b64 = data.get("state") if isinstance(data, dict) else None
        if not state_b64 or not isinstance(state_b64, str):
            raise InvalidRequestError("'state' is required and must be a non-empty string")
        gen = ModelRegistry.get().graphgen_continue_stream(state_b64)
        return _streaming_sse_response(gen)
def _streaming_sse_response(gen):
    """Wrap *gen* in a ``text/event-stream`` response that proxies won't buffer.

    The events produced by *gen* are serialised by ``_sse_iter``; the headers
    disable caching and ask nginx not to buffer the stream.
    """
    response = StreamingHttpResponse(
        _sse_iter(gen), content_type="text/event-stream")
    anti_buffering_headers = (
        ("Cache-Control", "no-cache"),
        ("X-Accel-Buffering", "no"),  # nginx
    )
    for header, value in anti_buffering_headers:
        response[header] = value
    return response
def _sse_iter(gen):
"""Convert generator of dicts into SSE events.
Progress events are split: metadata goes in ``event: progress``,
and preview images (when present) go in a separate ``event: preview``
so Postman's SSE viewer shows clean image updates.
"""
try:
for event in gen:
etype = event.get("type", "message")
preview = event.pop("preview", None)
yield f"event: {etype}\ndata: {json.dumps(event, separators=(',', ':'))}\n\n"
if preview:
yield f"event: preview\ndata: {preview}\n\n"
except Exception:
gen.close()
raise