import json

from django.http import StreamingHttpResponse
from rest_framework.renderers import JSONRenderer
from rest_framework.response import Response
from rest_framework.views import APIView

from api.exceptions import InvalidRequestError, ModelUnavailable
from api.renderers import EventStreamRenderer
from api.services.constants import GRAPHGEN_DATASETS, GRAPHGEN_SAMPLING_MODES
from api.services.registry import ModelRegistry


class GraphGenDatasetsView(APIView):
    """List the graph-generation datasets together with their loaded checkpoints."""

    def get(self, request):
        """Return ``{"datasets": [...]}``, one entry per ``GRAPHGEN_DATASETS`` item.

        Each entry copies the dataset's static metadata and adds
        ``available_model_types``: the registry's checkpoint list for that
        dataset, or ``[]`` when the registry has none for it.
        """
        checkpoints = ModelRegistry.get().graphgen_checkpoints_available
        datasets = [
            {
                "id": dataset_id,
                "name": meta["name"],
                "type": meta["type"],
                "description": meta["description"],
                "node_types": meta["node_types"],
                "edge_types": meta["edge_types"],
                "max_nodes": meta["max_nodes"],
                "available_model_types": checkpoints.get(dataset_id, []),
            }
            for dataset_id, meta in GRAPHGEN_DATASETS.items()
        ]
        return Response({"datasets": datasets})


class GraphGenSamplingModesView(APIView):
    """Expose the static ``GRAPHGEN_SAMPLING_MODES`` description."""

    def get(self, request):
        """Return ``{"sampling_modes": GRAPHGEN_SAMPLING_MODES}``."""
        return Response({"sampling_modes": GRAPHGEN_SAMPLING_MODES})


class GraphGenGenerateView(APIView):
    """Validate a graph-generation request and stream the run as Server-Sent Events."""

    # NOTE(review): EventStreamRenderer is listed first, presumably so that an
    # ``Accept: text/event-stream`` client passes content negotiation, while
    # JSONRenderer serves JSON clients / error bodies — confirm in api.renderers.
    renderer_classes = [EventStreamRenderer, JSONRenderer]

    def post(self, request):
        """Start a generation run and return it as an SSE streaming response.

        Body fields:
            dataset_id: a key of ``GRAPHGEN_DATASETS``.
            model_type: ``'discrete'`` or ``'continuous'``.
            sampling_mode: ``'standard'`` or ``'multiprox'``.
            num_nodes: optional int in ``[1, max_nodes]`` for the dataset.
            diffusion_steps: ('standard' only) int, default 500, clamped to [50, 1000].
            chain_frames: ('standard' only) int, default 10, clamped to [10, 30].
            multiprox_params: ('multiprox' only) required object, see
                ``_parse_multiprox_params``.

        Raises:
            InvalidRequestError: malformed, missing, or out-of-range parameters.
            ModelUnavailable: no ``model_type`` checkpoint is loaded for the dataset.
        """
        data = _request_object(request)
        registry = ModelRegistry.get()

        dataset_id = data.get("dataset_id")
        try:
            known_dataset = dataset_id in GRAPHGEN_DATASETS
        except TypeError:
            # Unhashable JSON values (lists, objects) cannot be dict keys; report
            # them as a bad request rather than letting TypeError become a 500.
            known_dataset = False
        if not known_dataset:
            raise InvalidRequestError(
                f"Unknown dataset_id '{dataset_id}'. Valid: {list(GRAPHGEN_DATASETS)}")

        model_type = data.get("model_type")
        if model_type not in ("discrete", "continuous"):
            raise InvalidRequestError("model_type must be 'discrete' or 'continuous'")

        sampling_mode = data.get("sampling_mode")
        if sampling_mode not in ("standard", "multiprox"):
            raise InvalidRequestError("sampling_mode must be 'standard' or 'multiprox'")

        available = registry.graphgen_checkpoints_available.get(dataset_id, [])
        if model_type not in available:
            raise ModelUnavailable(
                f"No {model_type} checkpoint available for dataset '{dataset_id}'")

        num_nodes = data.get("num_nodes")
        if num_nodes is not None:
            max_nodes = GRAPHGEN_DATASETS[dataset_id]["max_nodes"]
            # bool is a subclass of int: reject it so JSON ``true`` is not read as 1.
            if (isinstance(num_nodes, bool)
                    or not isinstance(num_nodes, int)
                    or not 1 <= num_nodes <= max_nodes):
                raise InvalidRequestError(
                    f"num_nodes must be an integer in [1, {max_nodes}] "
                    f"for dataset '{dataset_id}'")

        if sampling_mode == "standard":
            # Out-of-range values are clamped, not rejected.
            diffusion_steps = min(max(_int_param(data, "diffusion_steps", 500), 50), 1000)
            chain_frames = min(max(_int_param(data, "chain_frames", 10), 10), 30)
            gen = registry.graphgen_generate_stream(
                dataset_id, model_type, sampling_mode, num_nodes,
                diffusion_steps, chain_frames, None)
        else:
            multiprox_params = _parse_multiprox_params(data.get("multiprox_params"))
            gen = registry.graphgen_generate_stream(
                dataset_id, model_type, sampling_mode, num_nodes,
                None, None, multiprox_params)
        return _streaming_sse_response(gen)


class GraphGenContinueView(APIView):
    """Resume a generation run from an opaque ``state`` string and stream it as SSE."""

    renderer_classes = [EventStreamRenderer, JSONRenderer]

    def post(self, request):
        """Stream the continuation of the run encoded in the body's ``state`` field.

        Raises:
            InvalidRequestError: the body is not an object, or ``state`` is
                missing, empty, or not a string.
        """
        state_b64 = _request_object(request).get("state")
        if not state_b64 or not isinstance(state_b64, str):
            raise InvalidRequestError("'state' is required and must be a non-empty string")
        gen = ModelRegistry.get().graphgen_continue_stream(state_b64)
        return _streaming_sse_response(gen)


def _request_object(request):
    """Return ``request.data`` when it is a mapping; otherwise raise InvalidRequestError.

    Guards the ``.get`` calls in the views: a JSON array or scalar body would
    otherwise fail with AttributeError (HTTP 500).
    """
    data = request.data
    if not isinstance(data, dict):
        raise InvalidRequestError("Request body must be a JSON object")
    return data


def _int_param(params, key, default, label=None):
    """Read ``params[key]`` as an int, returning ``default`` when absent or null.

    Keeps the lenient ``int()`` coercion (numeric strings accepted, floats
    truncated) but turns unconvertible values into InvalidRequestError
    instead of a ValueError/TypeError surfacing as HTTP 500.
    """
    value = params.get(key)
    if value is None:
        return default
    try:
        return int(value)
    except (TypeError, ValueError, OverflowError):
        raise InvalidRequestError(f"{label or key} must be an integer") from None


def _float_param(params, key, default, label=None):
    """Read ``params[key]`` as a float, returning ``default`` when absent or null.

    Unconvertible values raise InvalidRequestError rather than a 500.
    """
    value = params.get(key)
    if value is None:
        return default
    try:
        return float(value)
    except (TypeError, ValueError, OverflowError):
        raise InvalidRequestError(f"{label or key} must be a number") from None


def _parse_multiprox_params(mp):
    """Validate the ``multiprox_params`` object and return the normalized dict.

    Fields (defaults in parentheses): ``m`` in [2, 100] (100); ``n`` >= 1 (10);
    ``t`` (0.5) and ``t_prime`` (0.1) with ``0 < t_prime <= t <= 1``;
    ``gibbs_chain_freq`` in [1, m] (``max(1, m // 10)``).

    Raises:
        InvalidRequestError: ``mp`` is missing/empty/not an object, or any
            field is non-numeric or out of range.
    """
    if not mp or not isinstance(mp, dict):
        raise InvalidRequestError("multiprox_params is required for multiprox sampling_mode")

    m = _int_param(mp, "m", 100, "multiprox_params.m")
    if not 2 <= m <= 100:
        raise InvalidRequestError("multiprox_params.m must be in [2, 100]")

    n = _int_param(mp, "n", 10, "multiprox_params.n")
    if n < 1:
        raise InvalidRequestError("multiprox_params.n must be >= 1")

    t = _float_param(mp, "t", 0.5, "multiprox_params.t")
    t_prime = _float_param(mp, "t_prime", 0.1, "multiprox_params.t_prime")
    # NaN fails every comparison, so it is rejected here as well.
    if not 0 < t_prime <= t <= 1:
        raise InvalidRequestError(
            "multiprox_params must satisfy 0 < t_prime <= t <= 1")

    gibbs_chain_freq = _int_param(
        mp, "gibbs_chain_freq", max(1, m // 10), "multiprox_params.gibbs_chain_freq")
    if not 1 <= gibbs_chain_freq <= m:
        raise InvalidRequestError(
            f"multiprox_params.gibbs_chain_freq must be in [1, {m}]")

    return {
        "n": n,
        "m": m,
        "t": t,
        "t_prime": t_prime,
        "gibbs_chain_freq": gibbs_chain_freq,
    }


def _streaming_sse_response(gen):
    """Build a StreamingHttpResponse with SSE format and anti-buffering headers."""
    resp = StreamingHttpResponse(
        _sse_iter(gen),
        content_type="text/event-stream",
    )
    resp["Cache-Control"] = "no-cache"
    resp["X-Accel-Buffering"] = "no"  # ask nginx not to buffer the stream
    return resp


def _sse_event(name, data):
    """Format one SSE event, giving each line of ``data`` its own ``data:`` field.

    Single-line data yields ``"event: <name>\\ndata: <data>\\n\\n"``. Splitting
    on CR/LF keeps a multi-line payload from terminating the event early.
    """
    lines = str(data).replace("\r\n", "\n").replace("\r", "\n").split("\n")
    body = "".join(f"data: {line}\n" for line in lines)
    return f"event: {name}\n{body}\n"


def _sse_iter(gen):
    """Convert an iterable of event dicts into SSE-formatted text chunks.

    Each event is emitted as ``event: <type>`` (default ``message``) with the
    dict minus its ``preview`` key as compact JSON. A truthy ``preview`` value
    is emitted separately as ``event: preview`` so Postman's SSE viewer shows
    clean image updates. The caller's dicts are not modified.

    ``gen`` is closed (when it has a ``close`` method) however iteration
    ends: exhaustion, an exception, or the server closing this generator
    after a client disconnect.
    """
    try:
        for event in gen:
            etype = event.get("type", "message")
            preview = event.get("preview")
            payload = {key: value for key, value in event.items() if key != "preview"}
            yield _sse_event(etype, json.dumps(payload, separators=(",", ":")))
            if preview:
                yield _sse_event("preview", preview)
    finally:
        # ``finally`` rather than ``except Exception``: a client disconnect
        # surfaces here as GeneratorExit, a BaseException that an
        # ``except Exception`` clause never sees.
        close = getattr(gen, "close", None)
        if close is not None:
            close()