import json
from django.http import StreamingHttpResponse
from rest_framework.renderers import JSONRenderer
from rest_framework.response import Response
from rest_framework.views import APIView
from api.exceptions import InvalidRequestError, ModelUnavailable
from api.renderers import EventStreamRenderer
from api.services.constants import GRAPHGEN_DATASETS, GRAPHGEN_SAMPLING_MODES
from api.services.registry import ModelRegistry
class GraphGenDatasetsView(APIView):
    """List the configured graph-generation datasets and, for each one,
    which model-type checkpoints are currently loadable."""

    def get(self, request):
        """Return all dataset metadata plus per-dataset available model types."""
        checkpoint_map = ModelRegistry.get().graphgen_checkpoints_available
        catalogue = [
            {
                "id": ds_id,
                "name": info["name"],
                "type": info["type"],
                "description": info["description"],
                "node_types": info["node_types"],
                "edge_types": info["edge_types"],
                "max_nodes": info["max_nodes"],
                # Datasets with no loaded checkpoint report an empty list.
                "available_model_types": checkpoint_map.get(ds_id, []),
            }
            for ds_id, info in GRAPHGEN_DATASETS.items()
        ]
        return Response({"datasets": catalogue})
class GraphGenSamplingModesView(APIView):
    """Expose the static catalogue of supported sampling modes."""

    def get(self, request):
        """Return the sampling modes straight from the constants module."""
        modes = GRAPHGEN_SAMPLING_MODES
        return Response({"sampling_modes": modes})
class GraphGenGenerateView(APIView):
    """Start a graph-generation run and stream its progress as SSE.

    Expected POST body:
        dataset_id       -- key in GRAPHGEN_DATASETS (required)
        model_type       -- 'discrete' or 'continuous' (required)
        sampling_mode    -- 'standard' or 'multiprox' (required)
        num_nodes        -- optional int in [1, dataset max_nodes]
        diffusion_steps, chain_frames -- standard mode only; clamped to range
        multiprox_params -- dict, required when sampling_mode == 'multiprox'

    Raises InvalidRequestError for malformed input and ModelUnavailable
    when no checkpoint matches the requested dataset/model_type pair.
    """

    renderer_classes = [EventStreamRenderer, JSONRenderer]

    @staticmethod
    def _as_int(value, field):
        """Coerce *value* with int(); map bad client input to a request error
        instead of letting ValueError/TypeError surface as a server error."""
        try:
            return int(value)
        except (TypeError, ValueError):
            raise InvalidRequestError(f"{field} must be an integer")

    @staticmethod
    def _as_float(value, field):
        """Coerce *value* with float(); map bad client input to a request error."""
        try:
            return float(value)
        except (TypeError, ValueError):
            raise InvalidRequestError(f"{field} must be a number")

    def post(self, request):
        data = request.data
        registry = ModelRegistry.get()

        dataset_id = data.get("dataset_id")
        if dataset_id not in GRAPHGEN_DATASETS:
            raise InvalidRequestError(
                f"Unknown dataset_id '{dataset_id}'. Valid: {list(GRAPHGEN_DATASETS)}")
        model_type = data.get("model_type")
        if model_type not in ("discrete", "continuous"):
            raise InvalidRequestError("model_type must be 'discrete' or 'continuous'")
        sampling_mode = data.get("sampling_mode")
        if sampling_mode not in ("standard", "multiprox"):
            raise InvalidRequestError("sampling_mode must be 'standard' or 'multiprox'")

        # Fail fast if the requested checkpoint was never loaded.
        available = registry.graphgen_checkpoints_available.get(dataset_id, [])
        if model_type not in available:
            raise ModelUnavailable(
                f"No {model_type} checkpoint available for dataset '{dataset_id}'")

        num_nodes = data.get("num_nodes")
        if num_nodes is not None:
            max_nodes = GRAPHGEN_DATASETS[dataset_id]["max_nodes"]
            # bool is an int subclass, so True/False must be rejected explicitly.
            if (not isinstance(num_nodes, int) or isinstance(num_nodes, bool)
                    or not 1 <= num_nodes <= max_nodes):
                raise InvalidRequestError(
                    f"num_nodes must be an integer in [1, {max_nodes}] for dataset '{dataset_id}'")

        if sampling_mode == "standard":
            # Tuning knobs are clamped into safe ranges rather than rejected.
            diffusion_steps = min(max(
                self._as_int(data.get("diffusion_steps", 500), "diffusion_steps"), 50), 1000)
            chain_frames = min(max(
                self._as_int(data.get("chain_frames", 10), "chain_frames"), 10), 30)
            gen = registry.graphgen_generate_stream(
                dataset_id, model_type, sampling_mode, num_nodes,
                diffusion_steps, chain_frames, None)
        else:
            mp = data.get("multiprox_params")
            if not mp or not isinstance(mp, dict):
                raise InvalidRequestError("multiprox_params is required for multiprox sampling_mode")
            m = self._as_int(mp.get("m", 100), "multiprox_params.m")
            if not (2 <= m <= 100):
                raise InvalidRequestError("multiprox_params.m must be in [2, 100]")
            n = self._as_int(mp.get("n", 10), "multiprox_params.n")
            if n < 1:
                raise InvalidRequestError("multiprox_params.n must be >= 1")
            t = self._as_float(mp.get("t", 0.5), "multiprox_params.t")
            t_prime = self._as_float(mp.get("t_prime", 0.1), "multiprox_params.t_prime")
            if not (0 < t_prime <= t <= 1):
                raise InvalidRequestError(
                    "multiprox_params must satisfy 0 < t_prime <= t <= 1")
            # Default chain frequency scales with m (one frame per ~10% of steps).
            gibbs_chain_freq = self._as_int(
                mp.get("gibbs_chain_freq", max(1, m // 10)),
                "multiprox_params.gibbs_chain_freq")
            if not (1 <= gibbs_chain_freq <= m):
                raise InvalidRequestError(
                    f"multiprox_params.gibbs_chain_freq must be in [1, {m}]")
            multiprox_params = {
                "n": n, "m": m, "t": t, "t_prime": t_prime,
                "gibbs_chain_freq": gibbs_chain_freq,
            }
            gen = registry.graphgen_generate_stream(
                dataset_id, model_type, sampling_mode, num_nodes,
                None, None, multiprox_params)
        return _streaming_sse_response(gen)
class GraphGenContinueView(APIView):
    """Resume a previously started generation run from an opaque state token."""

    renderer_classes = [EventStreamRenderer, JSONRenderer]

    def post(self, request):
        """Validate the base64 'state' payload and stream the continued run."""
        token = request.data.get("state")
        if not isinstance(token, str) or not token:
            raise InvalidRequestError("'state' is required and must be a non-empty string")
        stream = ModelRegistry.get().graphgen_continue_stream(token)
        return _streaming_sse_response(stream)
def _streaming_sse_response(gen):
    """Wrap *gen* in an SSE StreamingHttpResponse that proxies won't buffer."""
    response = StreamingHttpResponse(
        _sse_iter(gen), content_type="text/event-stream")
    anti_buffering = {
        "Cache-Control": "no-cache",
        "X-Accel-Buffering": "no",  # disables nginx response buffering
    }
    for header, value in anti_buffering.items():
        response[header] = value
    return response
def _sse_iter(gen):
"""Convert generator of dicts into SSE events.
Progress events are split: metadata goes in ``event: progress``,
and preview images (when present) go in a separate ``event: preview``
so Postman's SSE viewer shows clean image updates.
"""
try:
for event in gen:
etype = event.get("type", "message")
preview = event.pop("preview", None)
yield f"event: {etype}\ndata: {json.dumps(event, separators=(',', ':'))}\n\n"
if preview:
yield f"event: preview\ndata: {preview}\n\n"
except Exception:
gen.close()
raise