File size: 4,493 Bytes
bf2ae7f
 
58fd52b
bf2ae7f
c6540a4
bf2ae7f
 
 
 
6eb0631
bf2ae7f
 
410f2c5
2778014
 
 
 
 
40d8a9a
2778014
410f2c5
2778014
 
 
bf2ae7f
 
410f2c5
4d89c4d
2778014
bf2ae7f
 
 
 
 
 
 
 
 
 
 
2778014
 
 
 
 
4d89c4d
bf2ae7f
2778014
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4d89c4d
2778014
 
 
bf2ae7f
2778014
4d89c4d
 
410f2c5
2778014
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bf2ae7f
 
 
 
410f2c5
bf2ae7f
 
 
 
410f2c5
bf2ae7f
 
 
 
 
2778014
 
 
 
 
 
 
 
 
bf2ae7f
ef0bf4f
bf2ae7f
 
 
 
 
2778014
 
 
 
bf2ae7f
eae7491
bf2ae7f
 
 
 
58fd52b
4d89c4d
 
 
58fd52b
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
import logging
import os
import tempfile
from contextlib import asynccontextmanager

from fastapi import FastAPI
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware

from models import (
    GenerateRequest, GenerateResponse,
    PreviewRequest, PreviewResponse,
    HealthResponse, OpenRouterModel
)

# OpenRouter API key read once at import time; empty string when the variable
# is unset. Only its presence is exposed (as ``api_configured`` in health checks).
OPENROUTER_API_KEY = os.getenv("OPENROUTER_API_KEY", "")

# Shared DataDesigner instance: None at import, assigned by ``lifespan`` on startup.
data_designer = None


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Create the module-level ``data_designer`` instance when the app starts.

    The ``data_designer`` package is imported here, at startup, rather than at
    module import time. The system temporary directory is passed as the
    ``artifact_path``. Nothing is torn down after the ``yield``.
    """
    global data_designer
    from data_designer.interface import DataDesigner as _DataDesigner

    artifact_dir = tempfile.gettempdir()
    data_designer = _DataDesigner(artifact_path=artifact_dir)
    yield


# Metadata surfaced in the generated OpenAPI schema / docs pages.
_API_METADATA = {
    "title": "NeMo DataDesigner API",
    "description": "Synthetic data generation with DataDesigner + OpenRouter",
    "version": "5.3.0",
}
app = FastAPI(lifespan=lifespan, **_API_METADATA)

# NOTE(review): wildcard origins combined with allow_credentials=True means any
# website may send credentialed cross-origin requests. Browsers reject a literal
# "*" together with credentials, and Starlette may echo the caller's Origin
# instead. Confirm credentials are actually needed; if not, set it to False.
_CORS_OPTIONS = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_CORS_OPTIONS)


def _resolve_sampler_type(sampler_types, raw):
    """Return the member of the ``sampler_types`` enum identified by ``raw``.

    ``raw`` may already be a member, a member name in any letter case
    (``"CATEGORY"`` or ``"category"``), or a member value.

    Raises:
        ValueError: if ``raw`` matches no member.
    """
    if isinstance(raw, sampler_types):
        return raw
    key = str(raw)
    members = getattr(sampler_types, "__members__", None) or {}
    for candidate in (key, key.upper()):
        if candidate in members:
            return members[candidate]
    try:
        # Enum call syntax looks a member up by its value.
        return sampler_types(key)
    except ValueError:
        pass
    raise ValueError(
        f"Unknown sampler_type {raw!r}; expected one of: {', '.join(members)}"
    )


def build_config(request):
    """Translate an API request into a DataDesigner config builder.

    Args:
        request: A ``GenerateRequest`` or ``PreviewRequest``. Reads ``model``
            (an enum whose ``.value`` is the OpenRouter model id), ``columns``
            (each with ``name``, ``type`` and a ``params`` dict),
            ``temperature`` and ``max_tokens``.

    Returns:
        A ``DataDesignerConfigBuilder`` with one column per request column and
        a single OpenRouter model config. Every ``llm_text`` column uses that
        model config through the shared ``"or-model"`` alias.

    Raises:
        ValueError: if a column's ``type`` is neither ``"sampler"`` nor
            ``"llm_text"``, or if its ``sampler_type`` is unknown. Rejecting
            these up front stops the dataset from silently missing a requested
            column or using a different sampler than the one asked for.
    """
    import data_designer.config as dd
    from data_designer.config.models import ModelConfig, ChatCompletionInferenceParams

    model_alias = "or-model"
    model_id = request.model.value
    config_builder = dd.DataDesignerConfigBuilder()

    for col in request.columns:
        if col.type == "sampler":
            # A missing, null or empty sampler_type defaults to CATEGORY.
            sampler_type = _resolve_sampler_type(
                dd.SamplerType, col.params.get("sampler_type") or "CATEGORY"
            )
            config_builder.add_column(
                dd.SamplerColumnConfig(
                    name=col.name,
                    sampler_type=sampler_type,
                    params=get_sampler_params(sampler_type, col.params),
                )
            )
        elif col.type == "llm_text":
            config_builder.add_column(
                dd.LLMTextColumnConfig(
                    name=col.name,
                    model_alias=model_alias,
                    prompt=col.params.get("prompt", "Generate text"),
                )
            )
        else:
            raise ValueError(
                f"Unsupported column type {col.type!r} for column {col.name!r}; "
                "expected 'sampler' or 'llm_text'"
            )

    config_builder.add_model_config(
        ModelConfig(
            alias=model_alias,
            model=model_id,
            provider="openrouter",
            inference_parameters=ChatCompletionInferenceParams(
                temperature=request.temperature,
                max_tokens=request.max_tokens,
            ),
        )
    )
    return config_builder


def get_sampler_params(sampler_type, params):
    """Build the sampler params object for a sampler column.

    Only the ``CATEGORY`` sampler is mapped: its ``values`` come from
    ``params["values"]`` and default to ``["A", "B", "C"]``. Every other sampler
    type receives placeholder category params with ``values=["default"]``.

    NOTE(review): for non-CATEGORY sampler types these placeholder category
    params presumably do not match ``sampler_type``. Confirm how DataDesigner
    treats such a column, and add proper params per sampler type.
    """
    import data_designer.config as dd

    # Enum members expose ``.name``; any other object is compared by its str().
    try:
        label = sampler_type.name
    except AttributeError:
        label = str(sampler_type)

    chosen_values = params.get("values", ["A", "B", "C"]) if label == "CATEGORY" else ["default"]
    return dd.CategorySamplerParams(values=chosen_values)


def _health_response():
    """Build the status payload shared by ``/`` and ``/health``.

    ``status`` is always ``"healthy"``; the DataDesigner instance is not probed.
    ``api_configured`` reports only whether ``OPENROUTER_API_KEY`` is non-empty.
    The key itself is never returned.
    """
    return HealthResponse(status="healthy", model="data-designer", api_configured=bool(OPENROUTER_API_KEY))


@app.get("/", response_model=HealthResponse)
async def root():
    """Report service status at the root path (same payload as ``/health``)."""
    return _health_response()


@app.get("/health", response_model=HealthResponse)
async def health():
    """Report service status and whether the OpenRouter key is configured."""
    return _health_response()


@app.post("/generate", response_model=GenerateResponse)
async def generate(request: GenerateRequest):
    """Generate ``request.num_records`` synthetic records and return them.

    ``data_designer.create`` and ``load_dataset`` are synchronous calls (they are
    not awaited), and generation may include LLM-backed columns. Both therefore
    run in FastAPI's worker threadpool via ``run_in_threadpool``. Calling them
    directly inside this ``async def`` would block the event loop and stall every
    other request, including ``/health``, until generation finished.

    Returns:
        ``GenerateResponse(success=True, data=..., record_count=...)`` with one
        dict per generated row. On any failure, the exception is logged with its
        traceback and ``GenerateResponse(success=False, error=str(exc))`` is
        returned instead of an HTTP error status.
    """

    def _create_records(config_builder):
        # Runs in a worker thread; uses the module-level instance set by lifespan.
        result = data_designer.create(
            config_builder=config_builder,
            num_records=request.num_records,
            dataset_name="api-dataset",
        )
        return result.load_dataset().to_dict(orient="records")

    try:
        config_builder = build_config(request)
        data = await run_in_threadpool(_create_records, config_builder)
        return GenerateResponse(success=True, data=data, record_count=len(data))
    except Exception as e:
        logging.getLogger(__name__).exception("Dataset generation failed")
        return GenerateResponse(success=False, error=str(e))


@app.post("/preview", response_model=PreviewResponse)
async def preview(request: PreviewRequest):
    """Generate a single preview record for the requested columns.

    ``data_designer.preview`` is a synchronous call (it is not awaited) and may
    call the configured LLM, so it runs in FastAPI's worker threadpool via
    ``run_in_threadpool``. This keeps other requests responsive while the
    preview is produced.

    Returns:
        ``PreviewResponse(success=True, sample=...)`` holding the first preview
        row as a dict, or ``{}`` when the preview produced no rows. On any
        failure, the exception is logged with its traceback and
        ``PreviewResponse(success=False, error=str(exc))`` is returned.
    """

    def _first_preview_record(config_builder):
        # Runs in a worker thread; uses the module-level instance set by lifespan.
        preview_result = data_designer.preview(config_builder=config_builder, num_records=1)
        dataset = preview_result.dataset
        return dataset.to_dict(orient="records")[0] if len(dataset) > 0 else {}

    try:
        config_builder = build_config(request)
        sample = await run_in_threadpool(_first_preview_record, config_builder)
        return PreviewResponse(success=True, sample=sample)
    except Exception as e:
        logging.getLogger(__name__).exception("Preview generation failed")
        return PreviewResponse(success=False, error=str(e))


@app.get("/models")
async def list_models():
    """Return the static catalogue of OpenRouter models offered to clients.

    Each entry is a dict with ``id`` (the OpenRouter model id), ``name`` and
    ``description``.

    NOTE(review): this list is maintained by hand. Confirm it stays in sync with
    the ``OpenRouterModel`` enum that request models use for ``model``.
    """
    catalogue = (
        ("z-ai/glm-5", "GLM-5", "z.ai flagship"),
        ("openai/gpt-4o-mini", "GPT-4o Mini", "Fast & cheap"),
        ("anthropic/claude-sonnet-4.6", "Claude Sonnet 4.6", "Latest Claude"),
    )
    entries = [
        {"id": model_id, "name": label, "description": blurb}
        for model_id, label, blurb in catalogue
    ]
    return {"models": entries}