mindchain's picture
v5.3: Use z-ai/glm-5 via OpenRouter
4d89c4d verified
import os
from contextlib import asynccontextmanager
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
import tempfile
from models import (
GenerateRequest, GenerateResponse,
PreviewRequest, PreviewResponse,
HealthResponse, OpenRouterModel
)
# OpenRouter credential, read once at import time. An empty string means the
# variable is unset; the health endpoints report this via ``api_configured``.
OPENROUTER_API_KEY = os.getenv("OPENROUTER_API_KEY", "")

# Shared DataDesigner instance. It stays ``None`` until the application
# lifespan handler assigns it during startup.
data_designer = None
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Create the shared DataDesigner instance before the app starts serving.

    The instance is stored in the module-level ``data_designer`` global and
    writes its artifacts under the system temporary directory. Nothing is
    torn down after the ``yield``.

    Args:
        app: The FastAPI application this lifespan is attached to (unused).
    """
    global data_designer

    # Imported here rather than at module top, so the package is only
    # loaded when the application actually starts.
    from data_designer.interface import DataDesigner as _DataDesigner

    artifact_dir = tempfile.gettempdir()
    data_designer = _DataDesigner(artifact_path=artifact_dir)

    yield
# ASGI application object. Startup work is delegated to ``lifespan``; the
# title, description and version are passed as application metadata.
app = FastAPI(
    lifespan=lifespan,
    title="NeMo DataDesigner API",
    description="Synthetic data generation with DataDesigner + OpenRouter",
    version="5.3.0",
)
# Open CORS policy: any origin may call the API with any method and header.
#
# Credentials (cookies / HTTP auth) are deliberately NOT allowed. The FastAPI
# CORS documentation states that allow_origins, allow_methods and
# allow_headers cannot be ["*"] when allow_credentials is True. The endpoints
# in this module do not read cookies or auth headers, so wildcard access
# without credentials is used instead.
# To support credentialed browser requests, replace the wildcards with an
# explicit list of trusted origins before re-enabling allow_credentials.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)
def build_config(request):
    """Translate an API request into a DataDesigner config builder.

    Each requested column becomes either a sampler column or an LLM text
    column. A single model config with the alias ``"or-model"`` is registered
    for the LLM columns, pointing at the requested model through the
    ``"openrouter"`` provider.

    Args:
        request: Request object exposing ``model`` (whose ``value`` is the
            model id), ``columns`` (items with ``name``, ``type`` and a
            ``params`` mapping), ``temperature`` and ``max_tokens``.

    Returns:
        The populated ``DataDesignerConfigBuilder``.

    Raises:
        ValueError: If a column's ``type`` is neither ``"sampler"`` nor
            ``"llm_text"``.
    """
    import data_designer.config as dd
    from data_designer.config.models import ModelConfig, ChatCompletionInferenceParams

    config_builder = dd.DataDesignerConfigBuilder()
    model_id = request.model.value

    for col in request.columns:
        if col.type == "sampler":
            sampler_type_str = col.params.get("sampler_type", "CATEGORY")
            # The name comes from the request, so accept it only when it
            # resolves to a real SamplerType member. A bare getattr() could
            # hand back any class attribute (e.g. "__members__").
            sampler_type = getattr(dd.SamplerType, sampler_type_str, None)
            if not isinstance(sampler_type, dd.SamplerType):
                sampler_type = dd.SamplerType.CATEGORY
            params = get_sampler_params(sampler_type, col.params)
            config_builder.add_column(
                dd.SamplerColumnConfig(
                    name=col.name,
                    sampler_type=sampler_type,
                    params=params,
                )
            )
        elif col.type == "llm_text":
            config_builder.add_column(
                dd.LLMTextColumnConfig(
                    name=col.name,
                    model_alias="or-model",
                    prompt=col.params.get("prompt", "Generate text"),
                )
            )
        else:
            # Previously such columns were silently dropped, so the caller
            # got a dataset missing a column with no hint why.
            raise ValueError(
                f"Unsupported column type {col.type!r} for column {col.name!r}; "
                "expected 'sampler' or 'llm_text'"
            )

    model_config = ModelConfig(
        alias="or-model",
        model=model_id,
        provider="openrouter",
        inference_parameters=ChatCompletionInferenceParams(
            temperature=request.temperature,
            max_tokens=request.max_tokens,
        ),
    )
    config_builder.add_model_config(model_config)
    return config_builder
def get_sampler_params(sampler_type, params):
    """Build the parameter object for a sampler column.

    Only the ``CATEGORY`` sampler is configurable: its choices come from
    ``params["values"]``, defaulting to ``["A", "B", "C"]``. Any other
    sampler type receives category parameters with the single value
    ``"default"``.

    Args:
        sampler_type: The sampler type. Its ``name`` attribute is used when
            present, otherwise its ``str()`` form.
        params: Mapping of user-supplied column parameters.

    Returns:
        A ``dd.CategorySamplerParams`` instance.
    """
    import data_designer.config as dd

    try:
        kind = sampler_type.name
    except AttributeError:
        kind = str(sampler_type)

    if kind != "CATEGORY":
        # NOTE(review): non-CATEGORY sampler types still get category
        # params here; confirm DataDesigner accepts that pairing.
        return dd.CategorySamplerParams(values=["default"])

    choices = params.get("values", ["A", "B", "C"])
    return dd.CategorySamplerParams(values=choices)
@app.get("/", response_model=HealthResponse)
async def root():
    """Report service health at the API root.

    ``api_configured`` is true when ``OPENROUTER_API_KEY`` is non-empty.
    """
    key_present = bool(OPENROUTER_API_KEY)
    return HealthResponse(
        status="healthy",
        model="data-designer",
        api_configured=key_present,
    )
@app.get("/health", response_model=HealthResponse)
async def health():
    """Health probe returning the same payload as the root endpoint.

    The status is always ``"healthy"``; ``api_configured`` reflects whether
    ``OPENROUTER_API_KEY`` is a non-empty string.
    """
    fields = {
        "status": "healthy",
        "model": "data-designer",
        "api_configured": bool(OPENROUTER_API_KEY),
    }
    return HealthResponse(**fields)
@app.post("/generate", response_model=GenerateResponse)
async def generate(request: GenerateRequest):
    """Generate a synthetic dataset for the requested columns.

    Builds a DataDesigner config from the request, creates
    ``request.num_records`` records under the dataset name ``"api-dataset"``,
    and returns them as a list of row dicts.

    Failures are not raised to the client: they are logged with a traceback
    and reported as ``success=False`` with the exception text in ``error``.

    Args:
        request: Generation request; passed to ``build_config`` and read for
            ``num_records``.

    Returns:
        ``GenerateResponse`` with ``data`` and ``record_count`` on success,
        or ``success=False`` and ``error`` on failure.
    """
    import logging

    if data_designer is None:
        # The lifespan handler has not assigned the shared instance. Say so
        # explicitly rather than surfacing a NoneType AttributeError.
        return GenerateResponse(success=False, error="DataDesigner is not initialized")

    try:
        config_builder = build_config(request)
        # NOTE(review): this looks like a synchronous call made directly
        # inside a coroutine. If so, it blocks the event loop while it runs.
        # Confirm DataDesigner's thread-safety before offloading to a thread.
        result = data_designer.create(
            config_builder=config_builder,
            num_records=request.num_records,
            dataset_name="api-dataset",
        )
        df = result.load_dataset()
        data = df.to_dict(orient="records")
        return GenerateResponse(success=True, data=data, record_count=len(data))
    except Exception as e:
        # Top-level boundary: keep the traceback in the server log, since
        # the client only receives the exception message.
        logging.getLogger(__name__).exception("Dataset generation failed")
        return GenerateResponse(success=False, error=str(e))
@app.post("/preview", response_model=PreviewResponse)
async def preview(request: PreviewRequest):
    """Generate a single preview record for the requested columns.

    Builds a DataDesigner config from the request, asks for a one-record
    preview, and returns the first row as a dict (``{}`` when the preview
    dataset is empty).

    Failures are not raised to the client: they are logged with a traceback
    and reported as ``success=False`` with the exception text in ``error``.

    Args:
        request: Preview request; passed to ``build_config``.

    Returns:
        ``PreviewResponse`` with ``sample`` on success, or ``success=False``
        and ``error`` on failure.
    """
    import logging

    if data_designer is None:
        # The lifespan handler has not assigned the shared instance. Say so
        # explicitly rather than surfacing a NoneType AttributeError.
        return PreviewResponse(success=False, error="DataDesigner is not initialized")

    try:
        config_builder = build_config(request)
        preview_result = data_designer.preview(config_builder=config_builder, num_records=1)
        dataset = preview_result.dataset
        sample = dataset.to_dict(orient="records")[0] if len(dataset) > 0 else {}
        return PreviewResponse(success=True, sample=sample)
    except Exception as e:
        # Top-level boundary: keep the traceback in the server log, since
        # the client only receives the exception message.
        logging.getLogger(__name__).exception("Preview generation failed")
        return PreviewResponse(success=False, error=str(e))
@app.get("/models")
async def list_models():
    """Return the hard-coded list of model choices advertised by this API.

    Returns:
        A dict with one key, ``"models"``, holding a list of dicts with
        ``id``, ``name`` and ``description`` entries.
    """
    # (id, display name, short description)
    catalogue = (
        ("z-ai/glm-5", "GLM-5", "z.ai flagship"),
        ("openai/gpt-4o-mini", "GPT-4o Mini", "Fast & cheap"),
        ("anthropic/claude-sonnet-4.6", "Claude Sonnet 4.6", "Latest Claude"),
    )
    return {
        "models": [
            {"id": model_id, "name": label, "description": blurb}
            for model_id, label, blurb in catalogue
        ]
    }