Spaces:
Sleeping
Sleeping
File size: 4,493 Bytes
bf2ae7f 58fd52b bf2ae7f c6540a4 bf2ae7f 6eb0631 bf2ae7f 410f2c5 2778014 40d8a9a 2778014 410f2c5 2778014 bf2ae7f 410f2c5 4d89c4d 2778014 bf2ae7f 2778014 4d89c4d bf2ae7f 2778014 4d89c4d 2778014 bf2ae7f 2778014 4d89c4d 410f2c5 2778014 bf2ae7f 410f2c5 bf2ae7f 410f2c5 bf2ae7f 2778014 bf2ae7f ef0bf4f bf2ae7f 2778014 bf2ae7f eae7491 bf2ae7f 58fd52b 4d89c4d 58fd52b | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 | import os
from contextlib import asynccontextmanager
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
import tempfile
from models import (
GenerateRequest, GenerateResponse,
PreviewRequest, PreviewResponse,
HealthResponse, OpenRouterModel
)
# OpenRouter API key, read once from the environment at import time. An empty
# string means "not configured" (reported as api_configured by the health
# endpoints).
OPENROUTER_API_KEY = os.getenv("OPENROUTER_API_KEY", "")
# Shared DataDesigner instance; assigned by ``lifespan`` at application
# startup and None until then.
data_designer = None
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan hook: build the shared DataDesigner on startup.

    The instance is stored in the module-level ``data_designer`` global and
    writes its artifacts under the system temporary directory. Nothing is
    torn down after the ``yield``.
    """
    global data_designer
    # Imported lazily so the module can be imported without paying the
    # DataDesigner import cost until the app actually starts.
    from data_designer.interface import DataDesigner

    artifact_dir = tempfile.gettempdir()
    data_designer = DataDesigner(artifact_path=artifact_dir)
    yield
# The FastAPI application; ``lifespan`` initialises the shared DataDesigner.
app = FastAPI(
    lifespan=lifespan,
    title="NeMo DataDesigner API",
    version="5.3.0",
    description="Synthetic data generation with DataDesigner + OpenRouter",
)
# Allow browser clients from any origin, with any method and header.
# Credentials are deliberately NOT enabled. The CORS spec forbids combining a
# wildcard origin with credentials, and Starlette works around that by echoing
# back the caller's Origin header, which would let any website make
# credentialed requests. None of the endpoints in this module read cookies or
# HTTP auth, so credentialed CORS is not needed.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)
def build_config(request):
    """Translate an API request into a DataDesigner config builder.

    Each entry in ``request.columns`` becomes one column:

    * ``type == "sampler"`` -> ``SamplerColumnConfig``. The sampler kind comes
      from ``params["sampler_type"]``, interpreted as a ``SamplerType`` member
      *name* (default ``"CATEGORY"``). Names that are not members fall back
      to ``CATEGORY``. Per-type params come from ``get_sampler_params``.
    * ``type == "llm_text"`` -> ``LLMTextColumnConfig`` driven by
      ``params["prompt"]`` (default ``"Generate text"``).

    Columns of any other type are skipped. Exactly one OpenRouter model config
    is registered, under the same alias the LLM columns reference.

    Args:
        request: A request object exposing ``model`` (presumably an
            ``OpenRouterModel`` enum; its ``.value`` is passed as the model
            id), ``columns``, ``temperature`` and ``max_tokens``.

    Returns:
        The populated ``DataDesignerConfigBuilder``.
    """
    import data_designer.config as dd
    from data_designer.config.models import ModelConfig, ChatCompletionInferenceParams

    # One alias shared by every LLM column and the model config, so the two
    # cannot drift apart.
    model_alias = "or-model"
    config_builder = dd.DataDesignerConfigBuilder()
    model_id = request.model.value
    sampler_members = dd.SamplerType.__members__

    for col in request.columns:
        if col.type == "sampler":
            sampler_type_str = col.params.get("sampler_type", "CATEGORY")
            # Resolve against enum *members* only. A bare getattr() would also
            # return non-member attributes such as "name", "value" or "mro",
            # and raises TypeError when the value is not a string.
            if isinstance(sampler_type_str, str):
                sampler_type = sampler_members.get(sampler_type_str, dd.SamplerType.CATEGORY)
            else:
                sampler_type = dd.SamplerType.CATEGORY
            params = get_sampler_params(sampler_type, col.params)
            config_builder.add_column(
                dd.SamplerColumnConfig(
                    name=col.name,
                    sampler_type=sampler_type,
                    params=params,
                )
            )
        elif col.type == "llm_text":
            config_builder.add_column(
                dd.LLMTextColumnConfig(
                    name=col.name,
                    model_alias=model_alias,
                    prompt=col.params.get("prompt", "Generate text"),
                )
            )
        # NOTE(review): other column types are silently ignored; confirm the
        # request model restricts ``type`` to the two values handled here.

    model_config = ModelConfig(
        alias=model_alias,
        model=model_id,
        provider="openrouter",
        inference_parameters=ChatCompletionInferenceParams(
            temperature=request.temperature,
            max_tokens=request.max_tokens,
        ),
    )
    config_builder.add_model_config(model_config)
    return config_builder
def get_sampler_params(sampler_type, params):
    """Build the sampler params object for a sampler column.

    Only ``CATEGORY`` is really supported: its values come from
    ``params["values"]`` (default ``["A", "B", "C"]``). Every other sampler
    type gets a placeholder ``CategorySamplerParams(values=["default"])``.

    NOTE(review): the placeholder is category-shaped even for non-category
    sampler types; this likely does not match what DataDesigner expects for
    them. Confirm, and add per-type params if other types are needed.

    Args:
        sampler_type: A ``SamplerType`` member (or anything whose ``str()``
            is the type name).
        params: The column's raw params mapping.
    """
    import data_designer.config as dd

    if hasattr(sampler_type, "name"):
        kind = sampler_type.name
    else:
        kind = str(sampler_type)

    if kind != "CATEGORY":
        return dd.CategorySamplerParams(values=["default"])
    return dd.CategorySamplerParams(values=params.get("values", ["A", "B", "C"]))
@app.get("/", response_model=HealthResponse)
async def root():
    """Liveness probe at the API root; also reports whether a key is set."""
    key_present = bool(OPENROUTER_API_KEY)
    return HealthResponse(
        status="healthy",
        model="data-designer",
        api_configured=key_present,
    )
@app.get("/health", response_model=HealthResponse)
async def health():
    """Health-check endpoint; same payload as the root endpoint."""
    key_present = bool(OPENROUTER_API_KEY)
    return HealthResponse(
        status="healthy",
        model="data-designer",
        api_configured=key_present,
    )
@app.post("/generate", response_model=GenerateResponse)
async def generate(request: GenerateRequest):
    """Generate ``request.num_records`` synthetic records.

    Builds the DataDesigner config from *request*, runs ``create`` and returns
    the resulting rows as a list of dicts (``orient="records"``).

    Failures are reported in-band rather than as HTTP errors. Any exception
    is logged and returned as ``GenerateResponse(success=False, error=str(e))``.
    """
    import asyncio
    import logging

    def _create_records(builder):
        # Runs in a worker thread; see the to_thread call below.
        if data_designer is None:
            raise RuntimeError("DataDesigner is not initialized")
        result = data_designer.create(
            config_builder=builder,
            num_records=request.num_records,
            dataset_name="api-dataset",
        )
        df = result.load_dataset()
        return df.to_dict(orient="records")

    try:
        config_builder = build_config(request)
        # create()/load_dataset() are synchronous (called without await).
        # Running them directly in this coroutine would block the event loop,
        # and with it every other request including /health, for the whole run.
        # NOTE(review): requests can now overlap; confirm DataDesigner
        # tolerates concurrent create() calls sharing dataset_name="api-dataset".
        data = await asyncio.to_thread(_create_records, config_builder)
        return GenerateResponse(success=True, data=data, record_count=len(data))
    except Exception as e:
        # Boundary handler: keep the in-band error contract, but do not lose
        # the traceback.
        logging.getLogger(__name__).exception("Dataset generation failed")
        return GenerateResponse(success=False, error=str(e))
@app.post("/preview", response_model=PreviewResponse)
async def preview(request: PreviewRequest):
    """Generate a single preview record for *request*.

    Returns the first row of the preview dataset as a dict, or ``{}`` if the
    preview produced no rows.

    Failures are reported in-band. Any exception is logged and returned as
    ``PreviewResponse(success=False, error=str(e))``.
    """
    import asyncio
    import logging

    def _preview_one(builder):
        # Runs in a worker thread; see the to_thread call below.
        if data_designer is None:
            raise RuntimeError("DataDesigner is not initialized")
        preview_result = data_designer.preview(config_builder=builder, num_records=1)
        dataset = preview_result.dataset
        if dataset is None:
            # Previously surfaced as an opaque "object of type 'NoneType' has
            # no len()" error.
            raise RuntimeError("Preview returned no dataset")
        # Presumably a pandas DataFrame (uses to_dict(orient="records")).
        records = dataset.to_dict(orient="records")
        return records[0] if records else {}

    try:
        config_builder = build_config(request)
        # preview() is synchronous; offload it so the event loop stays
        # responsive while the LLM-backed preview runs.
        sample = await asyncio.to_thread(_preview_one, config_builder)
        return PreviewResponse(success=True, sample=sample)
    except Exception as e:
        logging.getLogger(__name__).exception("Preview failed")
        return PreviewResponse(success=False, error=str(e))
@app.get("/models")
async def list_models():
    """Return the static catalogue of OpenRouter models offered by the UI.

    NOTE(review): this list is hard-coded; confirm it stays in sync with the
    ``OpenRouterModel`` enum the request models accept.
    """
    catalog = (
        ("z-ai/glm-5", "GLM-5", "z.ai flagship"),
        ("openai/gpt-4o-mini", "GPT-4o Mini", "Fast & cheap"),
        ("anthropic/claude-sonnet-4.6", "Claude Sonnet 4.6", "Latest Claude"),
    )
    entries = [
        {"id": model_id, "name": label, "description": blurb}
        for model_id, label, blurb in catalog
    ]
    return {"models": entries}
|