File size: 5,905 Bytes
c6841f4
 
 
713aa17
4f1e196
 
 
c6841f4
713aa17
4f1e196
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c6841f4
 
 
713aa17
 
c6841f4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
713aa17
 
c6841f4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
import json

from django.http import StreamingHttpResponse
from rest_framework.renderers import JSONRenderer
from rest_framework.response import Response
from rest_framework.views import APIView

from api.exceptions import InvalidRequestError, ModelUnavailable
from api.renderers import EventStreamRenderer
from api.services.constants import GRAPHGEN_DATASETS, GRAPHGEN_SAMPLING_MODES
from api.services.registry import ModelRegistry


class GraphGenDatasetsView(APIView):
    """Read-only endpoint listing the configured GraphGen datasets."""

    def get(self, request):
        """Return ``{"datasets": [...]}`` with one entry per ``GRAPHGEN_DATASETS`` item.

        Each entry carries the dataset id, the metadata fields copied from the
        constants table, and the model types the registry reports checkpoints
        for (an empty list when the registry has no entry for the dataset).
        """
        registry = ModelRegistry.get()
        # Metadata keys copied verbatim into the response, in output order.
        copied_fields = (
            "name",
            "type",
            "description",
            "node_types",
            "edge_types",
            "max_nodes",
        )
        payload = [
            {
                "id": dataset_id,
                **{field: meta[field] for field in copied_fields},
                "available_model_types":
                    registry.graphgen_checkpoints_available.get(dataset_id, []),
            }
            for dataset_id, meta in GRAPHGEN_DATASETS.items()
        ]
        return Response({"datasets": payload})


class GraphGenSamplingModesView(APIView):
    """Read-only endpoint exposing the GraphGen sampling modes."""

    def get(self, request):
        """Return ``GRAPHGEN_SAMPLING_MODES`` under the ``sampling_modes`` key."""
        body = {"sampling_modes": GRAPHGEN_SAMPLING_MODES}
        return Response(body)


class GraphGenGenerateView(APIView):
    """Validate a graph-generation request and stream the result as SSE."""

    renderer_classes = [EventStreamRenderer, JSONRenderer]

    @staticmethod
    def _coerce(caster, value, field):
        """Convert *value* with *caster* (``int`` or ``float``).

        Raises:
            InvalidRequestError: if the conversion fails. Without this,
                malformed client input (``"abc"``, ``null``, a list, an
                infinite float) would escape as an unhandled
                ValueError/TypeError/OverflowError.
        """
        try:
            return caster(value)
        except (TypeError, ValueError, OverflowError):
            expected = "an integer" if caster is int else "a number"
            raise InvalidRequestError(f"{field} must be {expected}") from None

    def post(self, request):
        """Start a generation run and return a streaming SSE response.

        Request body fields read here:
            dataset_id: must be a key of ``GRAPHGEN_DATASETS``.
            model_type: ``"discrete"`` or ``"continuous"``; a checkpoint of
                that type must be available for the dataset.
            sampling_mode: ``"standard"`` or ``"multiprox"``.
            num_nodes: optional integer in ``[1, max_nodes]`` of the dataset.
            diffusion_steps, chain_frames: standard mode only; clamped to
                ``[50, 1000]`` and ``[10, 30]`` respectively.
            multiprox_params: multiprox mode only; non-empty object with
                optional ``m``, ``n``, ``t``, ``t_prime``, ``gibbs_chain_freq``.

        Raises:
            InvalidRequestError: on any malformed or out-of-range field.
            ModelUnavailable: if no checkpoint exists for the requested
                dataset/model type combination.
        """
        data = request.data
        # A JSON array or scalar body has no .get(); reject it explicitly
        # instead of failing with AttributeError below.
        if not isinstance(data, dict):
            raise InvalidRequestError("Request body must be a JSON object")
        registry = ModelRegistry.get()

        dataset_id = data.get("dataset_id")
        try:
            dataset_known = dataset_id in GRAPHGEN_DATASETS
        except TypeError:
            # Unhashable value (e.g. a list or object) can never be a valid id.
            dataset_known = False
        if not dataset_known:
            raise InvalidRequestError(
                f"Unknown dataset_id '{dataset_id}'. Valid: {list(GRAPHGEN_DATASETS)}")

        model_type = data.get("model_type")
        if model_type not in ("discrete", "continuous"):
            raise InvalidRequestError("model_type must be 'discrete' or 'continuous'")

        sampling_mode = data.get("sampling_mode")
        if sampling_mode not in ("standard", "multiprox"):
            raise InvalidRequestError("sampling_mode must be 'standard' or 'multiprox'")

        available = registry.graphgen_checkpoints_available.get(dataset_id, [])
        if model_type not in available:
            raise ModelUnavailable(
                f"No {model_type} checkpoint available for dataset '{dataset_id}'")

        num_nodes = data.get("num_nodes")
        if num_nodes is not None:
            max_nodes = GRAPHGEN_DATASETS[dataset_id]["max_nodes"]
            # bool is a subclass of int, so JSON true/false must be excluded
            # explicitly or `true` would be accepted as num_nodes == 1.
            is_integer = isinstance(num_nodes, int) and not isinstance(num_nodes, bool)
            if not is_integer or not (1 <= num_nodes <= max_nodes):
                raise InvalidRequestError(
                    f"num_nodes must be an integer in [1, {max_nodes}] for dataset '{dataset_id}'")

        if sampling_mode == "standard":
            # Out-of-range values are clamped rather than rejected.
            diffusion_steps = min(max(self._coerce(
                int, data.get("diffusion_steps", 500), "diffusion_steps"), 50), 1000)
            chain_frames = min(max(self._coerce(
                int, data.get("chain_frames", 10), "chain_frames"), 10), 30)
            gen = registry.graphgen_generate_stream(
                dataset_id, model_type, sampling_mode, num_nodes,
                diffusion_steps, chain_frames, None)
        else:
            mp = data.get("multiprox_params")
            if not mp or not isinstance(mp, dict):
                raise InvalidRequestError("multiprox_params is required for multiprox sampling_mode")

            m = self._coerce(int, mp.get("m", 100), "multiprox_params.m")
            if not (2 <= m <= 100):
                raise InvalidRequestError("multiprox_params.m must be in [2, 100]")

            n = self._coerce(int, mp.get("n", 10), "multiprox_params.n")
            if n < 1:
                raise InvalidRequestError("multiprox_params.n must be >= 1")

            t = self._coerce(float, mp.get("t", 0.5), "multiprox_params.t")
            t_prime = self._coerce(
                float, mp.get("t_prime", 0.1), "multiprox_params.t_prime")
            # NaN fails every comparison, so it is rejected here as well.
            if not (0 < t_prime <= t <= 1):
                raise InvalidRequestError(
                    "multiprox_params must satisfy 0 < t_prime <= t <= 1")

            # Default frequency is a tenth of m, but never below 1.
            gibbs_chain_freq = self._coerce(
                int, mp.get("gibbs_chain_freq", max(1, m // 10)),
                "multiprox_params.gibbs_chain_freq")
            if not (1 <= gibbs_chain_freq <= m):
                raise InvalidRequestError(
                    f"multiprox_params.gibbs_chain_freq must be in [1, {m}]")

            multiprox_params = {
                "n": n, "m": m, "t": t, "t_prime": t_prime,
                "gibbs_chain_freq": gibbs_chain_freq,
            }
            gen = registry.graphgen_generate_stream(
                dataset_id, model_type, sampling_mode, num_nodes,
                None, None, multiprox_params)

        return _streaming_sse_response(gen)


class GraphGenContinueView(APIView):
    """Resume a generation run from a client-supplied state and stream it as SSE."""

    renderer_classes = [EventStreamRenderer, JSONRenderer]

    def post(self, request):
        """Continue generation from the ``state`` string in the request body.

        Raises:
            InvalidRequestError: if the body is not a JSON object, or if
                ``state`` is missing, empty, or not a string.
        """
        data = request.data
        # A JSON array or scalar body has no .get(); report it as a bad
        # request instead of failing with AttributeError.
        if not isinstance(data, dict):
            raise InvalidRequestError("Request body must be a JSON object")

        state_b64 = data.get("state")
        if not state_b64 or not isinstance(state_b64, str):
            raise InvalidRequestError("'state' is required and must be a non-empty string")
        gen = ModelRegistry.get().graphgen_continue_stream(state_b64)
        return _streaming_sse_response(gen)


def _streaming_sse_response(gen):
    """Wrap *gen* in a ``text/event-stream`` StreamingHttpResponse.

    The event dicts produced by *gen* are serialized by ``_sse_iter``; the
    response also carries headers that disable caching and proxy buffering.
    """
    response = StreamingHttpResponse(_sse_iter(gen), content_type="text/event-stream")
    anti_buffering_headers = (
        ("Cache-Control", "no-cache"),
        ("X-Accel-Buffering", "no"),  # nginx
    )
    for header, value in anti_buffering_headers:
        response[header] = value
    return response


def _sse_iter(gen):
    """Convert generator of dicts into SSE events.

    Progress events are split: metadata goes in ``event: progress``,
    and preview images (when present) go in a separate ``event: preview``
    so Postman's SSE viewer shows clean image updates.
    """
    try:
        for event in gen:
            etype = event.get("type", "message")
            preview = event.pop("preview", None)
            yield f"event: {etype}\ndata: {json.dumps(event, separators=(',', ':'))}\n\n"
            if preview:
                yield f"event: preview\ndata: {preview}\n\n"
    except Exception:
        gen.close()
        raise