import os
import tempfile
from contextlib import asynccontextmanager

from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware

from models import (
    GenerateRequest,
    GenerateResponse,
    PreviewRequest,
    PreviewResponse,
    HealthResponse,
    OpenRouterModel,
)

OPENROUTER_API_KEY = os.environ.get("OPENROUTER_API_KEY", "")

# Populated once at startup by lifespan(); None until the app has started.
data_designer = None


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Create the shared DataDesigner instance when the app starts."""
    global data_designer
    # Imported lazily so the module can be imported without loading the library.
    from data_designer.interface import DataDesigner

    data_designer = DataDesigner(artifact_path=tempfile.gettempdir())
    yield


app = FastAPI(
    title="NeMo DataDesigner API",
    description="Synthetic data generation with DataDesigner + OpenRouter",
    version="5.3.0",
    lifespan=lifespan,
)

# Wide-open CORS policy: every origin, method, and header is allowed.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


def build_config(request):
    """Translate an API request into a DataDesigner config builder."""
    import data_designer.config as dd
    from data_designer.config.models import ModelConfig, ChatCompletionInferenceParams

    config_builder = dd.DataDesignerConfigBuilder()
    model_id = request.model.value

    for col in request.columns:
        if col.type == "sampler":
            # Unknown sampler type names fall back to CATEGORY.
            sampler_type_str = col.params.get("sampler_type", "CATEGORY")
            sampler_type = getattr(dd.SamplerType, sampler_type_str, dd.SamplerType.CATEGORY)
            params = get_sampler_params(sampler_type, col.params)
            config_builder.add_column(
                dd.SamplerColumnConfig(
                    name=col.name,
                    sampler_type=sampler_type,
                    params=params,
                )
            )
        elif col.type == "llm_text":
            config_builder.add_column(
                dd.LLMTextColumnConfig(
                    name=col.name,
                    model_alias="or-model",
                    prompt=col.params.get("prompt", "Generate text"),
                )
            )
        # Columns of any other type are silently skipped.

    # Every LLM column refers to this single model through the "or-model" alias.
    model_config = ModelConfig(
        alias="or-model",
        model=model_id,
        provider="openrouter",
        inference_parameters=ChatCompletionInferenceParams(
            temperature=request.temperature,
            max_tokens=request.max_tokens,
        ),
    )
    config_builder.add_model_config(model_config)
    return config_builder


def get_sampler_params(sampler_type, params):
    """Build sampler params; only CATEGORY samplers are handled explicitly."""
    import data_designer.config as dd

    type_name = sampler_type.name if hasattr(sampler_type, "name") else str(sampler_type)
    if type_name == "CATEGORY":
        return dd.CategorySamplerParams(values=params.get("values", ["A", "B", "C"]))
    else:
        # Any other sampler type gets a single-value category placeholder.
        return dd.CategorySamplerParams(values=["default"])


@app.get("/", response_model=HealthResponse)
async def root():
    return HealthResponse(status="healthy", model="data-designer", api_configured=bool(OPENROUTER_API_KEY))


@app.get("/health", response_model=HealthResponse)
async def health():
    return HealthResponse(status="healthy", model="data-designer", api_configured=bool(OPENROUTER_API_KEY))


@app.post("/generate", response_model=GenerateResponse)
async def generate(request: GenerateRequest):
    try:
        config_builder = build_config(request)
        result = data_designer.create(
            config_builder=config_builder,
            num_records=request.num_records,
            dataset_name="api-dataset",
        )
        df = result.load_dataset()
        data = df.to_dict(orient="records")
        return GenerateResponse(success=True, data=data, record_count=len(data))
    except Exception as e:
        # Failures are reported in the response body, not as an HTTP error status.
        return GenerateResponse(success=False, error=str(e))


@app.post("/preview", response_model=PreviewResponse)
async def preview(request: PreviewRequest):
    try:
        config_builder = build_config(request)
        preview_result = data_designer.preview(config_builder=config_builder, num_records=1)
        # Return the first generated record, or an empty dict if nothing came back.
        sample = preview_result.dataset.to_dict(orient="records")[0] if len(preview_result.dataset) > 0 else {}
        return PreviewResponse(success=True, sample=sample)
    except Exception as e:
        return PreviewResponse(success=False, error=str(e))


@app.get("/models")
async def list_models():
    return {"models": [
        {"id": "z-ai/glm-5", "name": "GLM-5", "description": "z.ai flagship"},
        {"id": "openai/gpt-4o-mini", "name": "GPT-4o Mini", "description": "Fast & cheap"},
        {"id": "anthropic/claude-sonnet-4.6", "name": "Claude Sonnet 4.6", "description": "Latest Claude"},
    ]}