Andrej Janchevski
feat(multiproxan): stream graph-generation progress and refine comm20 render
713aa17 | import json | |
| from django.http import StreamingHttpResponse | |
| from rest_framework.renderers import JSONRenderer | |
| from rest_framework.response import Response | |
| from rest_framework.views import APIView | |
| from api.exceptions import InvalidRequestError, ModelUnavailable | |
| from api.renderers import EventStreamRenderer | |
| from api.services.constants import GRAPHGEN_DATASETS, GRAPHGEN_SAMPLING_MODES | |
| from api.services.registry import ModelRegistry | |
class GraphGenDatasetsView(APIView):
    """Read-only endpoint describing every dataset in ``GRAPHGEN_DATASETS``."""

    # Metadata keys copied verbatim from each GRAPHGEN_DATASETS entry,
    # in the order they appear in the response.
    _META_FIELDS = (
        "name",
        "type",
        "description",
        "node_types",
        "edge_types",
        "max_nodes",
    )

    @classmethod
    def _serialize(cls, dataset_id, meta, available_types):
        """Build the response entry for a single dataset."""
        entry = {"id": dataset_id}
        for field in cls._META_FIELDS:
            entry[field] = meta[field]
        entry["available_model_types"] = available_types
        return entry

    def get(self, request):
        """Return ``{"datasets": [...]}`` with one entry per known dataset.

        Each entry carries the dataset metadata plus the model types the
        registry reports under ``graphgen_checkpoints_available`` (an empty
        list when the registry has no entry for the dataset).
        """
        registry = ModelRegistry.get()
        return Response({
            "datasets": [
                self._serialize(
                    dataset_id,
                    meta,
                    registry.graphgen_checkpoints_available.get(dataset_id, []),
                )
                for dataset_id, meta in GRAPHGEN_DATASETS.items()
            ],
        })
class GraphGenSamplingModesView(APIView):
    """Read-only endpoint exposing the ``GRAPHGEN_SAMPLING_MODES`` constant."""

    def get(self, request):
        """Return the sampling modes under the ``sampling_modes`` key."""
        body = {"sampling_modes": GRAPHGEN_SAMPLING_MODES}
        return Response(body)
class GraphGenGenerateView(APIView):
    """Validate a graph-generation request and stream its progress as SSE.

    The request body must be a JSON object with ``dataset_id``,
    ``model_type`` (``discrete``/``continuous``) and ``sampling_mode``
    (``standard``/``multiprox``); ``num_nodes`` is optional. Standard
    sampling reads ``diffusion_steps`` and ``chain_frames`` (clamped to
    [50, 1000] and [10, 30]); multiprox sampling requires a
    ``multiprox_params`` object.
    """

    renderer_classes = [EventStreamRenderer, JSONRenderer]

    @staticmethod
    def _coerce_number(value, caster, field):
        """Convert ``value`` with ``caster`` (``int`` or ``float``).

        Raises:
            InvalidRequestError: when the value cannot be converted, so a
                malformed client value is reported as a request error
                instead of a bare ValueError/TypeError/OverflowError
                escaping the view.
        """
        try:
            return caster(value)
        except (TypeError, ValueError, OverflowError) as exc:
            expected = "an integer" if caster is int else "a number"
            raise InvalidRequestError(f"{field} must be {expected}") from exc

    @classmethod
    def _parse_multiprox_params(cls, mp):
        """Validate ``multiprox_params`` and return the normalised dict.

        Raises:
            InvalidRequestError: when the object is missing, not a dict, or
                any field is non-numeric or out of range.
        """
        if not mp or not isinstance(mp, dict):
            raise InvalidRequestError(
                "multiprox_params is required for multiprox sampling_mode")

        m = cls._coerce_number(mp.get("m", 100), int, "multiprox_params.m")
        if not (2 <= m <= 100):
            raise InvalidRequestError("multiprox_params.m must be in [2, 100]")

        n = cls._coerce_number(mp.get("n", 10), int, "multiprox_params.n")
        if n < 1:
            raise InvalidRequestError("multiprox_params.n must be >= 1")

        t = cls._coerce_number(mp.get("t", 0.5), float, "multiprox_params.t")
        t_prime = cls._coerce_number(
            mp.get("t_prime", 0.1), float, "multiprox_params.t_prime")
        # NaN fails every comparison, so it is rejected here as well.
        if not (0 < t_prime <= t <= 1):
            raise InvalidRequestError(
                "multiprox_params must satisfy 0 < t_prime <= t <= 1")

        # Default frequency is a tenth of m, but never below 1.
        gibbs_chain_freq = cls._coerce_number(
            mp.get("gibbs_chain_freq", max(1, m // 10)),
            int, "multiprox_params.gibbs_chain_freq")
        if not (1 <= gibbs_chain_freq <= m):
            raise InvalidRequestError(
                f"multiprox_params.gibbs_chain_freq must be in [1, {m}]")

        return {
            "n": n, "m": m, "t": t, "t_prime": t_prime,
            "gibbs_chain_freq": gibbs_chain_freq,
        }

    def post(self, request):
        """Validate the body and return a streaming SSE response.

        Raises:
            InvalidRequestError: for a non-object body or any invalid field.
            ModelUnavailable: when the registry lists no checkpoint of the
                requested ``model_type`` for the dataset.
        """
        data = request.data
        # A JSON body may legally be a list or scalar; those have no .get().
        if not isinstance(data, dict):
            raise InvalidRequestError("Request body must be a JSON object")
        registry = ModelRegistry.get()

        dataset_id = data.get("dataset_id")
        try:
            known_dataset = dataset_id in GRAPHGEN_DATASETS
        except TypeError:
            # Unhashable JSON value (list/object) can never be a valid id.
            known_dataset = False
        if not known_dataset:
            raise InvalidRequestError(
                f"Unknown dataset_id '{dataset_id}'. Valid: {list(GRAPHGEN_DATASETS)}")

        model_type = data.get("model_type")
        if model_type not in ("discrete", "continuous"):
            raise InvalidRequestError("model_type must be 'discrete' or 'continuous'")

        sampling_mode = data.get("sampling_mode")
        if sampling_mode not in ("standard", "multiprox"):
            raise InvalidRequestError("sampling_mode must be 'standard' or 'multiprox'")

        available = registry.graphgen_checkpoints_available.get(dataset_id, [])
        if model_type not in available:
            raise ModelUnavailable(
                f"No {model_type} checkpoint available for dataset '{dataset_id}'")

        num_nodes = data.get("num_nodes")
        if num_nodes is not None:
            max_nodes = GRAPHGEN_DATASETS[dataset_id]["max_nodes"]
            # bool is a subclass of int; JSON true/false is not a node count.
            is_int = isinstance(num_nodes, int) and not isinstance(num_nodes, bool)
            if not is_int or not (1 <= num_nodes <= max_nodes):
                raise InvalidRequestError(
                    f"num_nodes must be an integer in [1, {max_nodes}] for dataset '{dataset_id}'")

        # Only the arguments relevant to the chosen mode are set; the rest
        # are passed as None.
        diffusion_steps = chain_frames = multiprox_params = None
        if sampling_mode == "standard":
            diffusion_steps = min(max(self._coerce_number(
                data.get("diffusion_steps", 500), int, "diffusion_steps"), 50), 1000)
            chain_frames = min(max(self._coerce_number(
                data.get("chain_frames", 10), int, "chain_frames"), 10), 30)
        else:
            multiprox_params = self._parse_multiprox_params(
                data.get("multiprox_params"))

        gen = registry.graphgen_generate_stream(
            dataset_id, model_type, sampling_mode, num_nodes,
            diffusion_steps, chain_frames, multiprox_params)
        return _streaming_sse_response(gen)
class GraphGenContinueView(APIView):
    """Resume graph generation from a client-supplied ``state`` string."""

    renderer_classes = [EventStreamRenderer, JSONRenderer]

    def post(self, request):
        """Validate ``state`` and stream the continued generation as SSE.

        Raises:
            InvalidRequestError: when the body is not a JSON object, or
                ``state`` is missing, empty, or not a string.
        """
        data = request.data
        # A JSON body may be a list or scalar, which has no .get(); treat
        # that the same as a missing 'state' instead of crashing.
        state_b64 = data.get("state") if isinstance(data, dict) else None
        if not state_b64 or not isinstance(state_b64, str):
            raise InvalidRequestError("'state' is required and must be a non-empty string")
        gen = ModelRegistry.get().graphgen_continue_stream(state_b64)
        return _streaming_sse_response(gen)
def _streaming_sse_response(gen):
    """Wrap ``gen`` in a ``text/event-stream`` StreamingHttpResponse.

    The event dicts produced by ``gen`` are serialised by ``_sse_iter``;
    the extra headers ask caches and proxies not to buffer the stream.
    """
    response = StreamingHttpResponse(
        _sse_iter(gen), content_type="text/event-stream")
    anti_buffering_headers = (
        ("Cache-Control", "no-cache"),
        ("X-Accel-Buffering", "no"),  # nginx
    )
    for header, value in anti_buffering_headers:
        response[header] = value
    return response
| def _sse_iter(gen): | |
| """Convert generator of dicts into SSE events. | |
| Progress events are split: metadata goes in ``event: progress``, | |
| and preview images (when present) go in a separate ``event: preview`` | |
| so Postman's SSE viewer shows clean image updates. | |
| """ | |
| try: | |
| for event in gen: | |
| etype = event.get("type", "message") | |
| preview = event.pop("preview", None) | |
| yield f"event: {etype}\ndata: {json.dumps(event, separators=(',', ':'))}\n\n" | |
| if preview: | |
| yield f"event: preview\ndata: {preview}\n\n" | |
| except Exception: | |
| gen.close() | |
| raise | |