mindchain commited on
Commit
bf2ae7f
·
verified ·
1 Parent(s): e3d4ce2

Upload app.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +236 -0
app.py ADDED
@@ -0,0 +1,236 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from contextlib import asynccontextmanager
3
+ from fastapi import FastAPI, HTTPException
4
+ from fastapi.middleware.cors import CORSMiddleware
5
+ from typing import Any
6
+
7
+ from models import (
8
+ GenerateRequest, GenerateResponse,
9
+ PreviewRequest, PreviewResponse,
10
+ HealthResponse, ZaiModel, SamplerType
11
+ )
12
+
13
+ # Configure z.ai as OpenAI-compatible provider
14
+ ZAI_API_KEY = os.environ.get("ZAI_API_KEY", "")
15
+ ZAI_BASE_URL = "https://api.z.ai/api/anthropic"
16
+
17
+ os.environ["OPENAI_API_KEY"] = ZAI_API_KEY
18
+ os.environ["OPENAI_API_BASE"] = ZAI_BASE_URL
19
+
20
+ # Global DataDesigner instance
21
+ data_designer = None
22
+
23
+
24
+ @asynccontextmanager
25
+ async def lifespan(app: FastAPI):
26
+ global data_designer
27
+ from data_designer.interface import DataDesigner
28
+ data_designer = DataDesigner()
29
+ yield
30
+
31
+
32
+ app = FastAPI(
33
+ title="NeMo DataDesigner API",
34
+ description="Synthetic data generation with NVIDIA NeMo DataDesigner and z.ai",
35
+ version="1.0.0",
36
+ lifespan=lifespan
37
+ )
38
+
39
+ app.add_middleware(
40
+ CORSMiddleware,
41
+ allow_origins=["*"],
42
+ allow_credentials=True,
43
+ allow_methods=["*"],
44
+ allow_headers=["*"],
45
+ )
46
+
47
+
48
+ def build_config(request: GenerateRequest | PreviewRequest):
49
+ """Build DataDesigner configuration from request."""
50
+ import data_designer.config as dd
51
+ from data_designer.config.models import ModelConfig, ChatCompletionInferenceParams
52
+
53
+ config_builder = dd.DataDesignerConfigBuilder()
54
+
55
+ model_id = request.model.value
56
+
57
+ # Process columns
58
+ for col in request.columns:
59
+ if col.type == "sampler":
60
+ sampler_type_str = col.params.get("sampler_type", "CATEGORY")
61
+ sampler_type = getattr(dd.SamplerType, sampler_type_str, dd.SamplerType.CATEGORY)
62
+ params_class = get_sampler_params(sampler_type, col.params)
63
+ config_builder.add_column(
64
+ dd.SamplerColumnConfig(
65
+ name=col.name,
66
+ sampler_type=sampler_type,
67
+ params=params_class,
68
+ )
69
+ )
70
+ elif col.type == "llm_text":
71
+ config_builder.add_column(
72
+ dd.LLMTextColumnConfig(
73
+ name=col.name,
74
+ model_alias="zai-model",
75
+ prompt=col.params.get("prompt", "Generate text"),
76
+ )
77
+ )
78
+ elif col.type == "llm_code":
79
+ config_builder.add_column(
80
+ dd.LLMCodeColumnConfig(
81
+ name=col.name,
82
+ model_alias="zai-model",
83
+ prompt=col.params.get("prompt", "Generate code"),
84
+ language=col.params.get("language", "python"),
85
+ )
86
+ )
87
+ elif col.type == "llm_structured":
88
+ config_builder.add_column(
89
+ dd.LLMStructuredColumnConfig(
90
+ name=col.name,
91
+ model_alias="zai-model",
92
+ prompt=col.params.get("prompt", "Generate structured data"),
93
+ schema=col.params.get("schema", {}),
94
+ )
95
+ )
96
+
97
+ # Add model config
98
+ model_config = ModelConfig(
99
+ alias="zai-model",
100
+ model=f"openai/{model_id}",
101
+ provider="openai",
102
+ inference_parameters=ChatCompletionInferenceParams(
103
+ temperature=request.temperature,
104
+ max_tokens=request.max_tokens,
105
+ ),
106
+ )
107
+ config_builder.add_model_config(model_config)
108
+
109
+ return config_builder
110
+
111
+
112
+ def get_sampler_params(sampler_type, params: dict) -> Any:
113
+ """Get appropriate sampler params based on type."""
114
+ import data_designer.config as dd
115
+
116
+ type_name = sampler_type.name if hasattr(sampler_type, 'name') else str(sampler_type)
117
+
118
+ if type_name == "CATEGORY":
119
+ return dd.CategorySamplerParams(
120
+ values=params.get("values", ["A", "B", "C"])
121
+ )
122
+ elif type_name == "UNIFORM":
123
+ return dd.UniformSamplerParams(
124
+ low=params.get("low", 0),
125
+ high=params.get("high", 100)
126
+ )
127
+ elif type_name == "GAUSSIAN":
128
+ return dd.GaussianSamplerParams(
129
+ mean=params.get("mean", 0),
130
+ std=params.get("std", 1)
131
+ )
132
+ elif type_name == "DATETIME":
133
+ return dd.DateTimeSamplerParams(
134
+ start_date=params.get("start_date", "2020-01-01"),
135
+ end_date=params.get("end_date", "2025-12-31")
136
+ )
137
+ else:
138
+ return dd.CategorySamplerParams(values=["default"])
139
+
140
+
141
+ @app.get("/", response_model=HealthResponse)
142
+ async def root():
143
+ """Health check endpoint."""
144
+ return HealthResponse(
145
+ status="healthy",
146
+ model="data-designer",
147
+ api_configured=bool(ZAI_API_KEY)
148
+ )
149
+
150
+
151
+ @app.get("/health", response_model=HealthResponse)
152
+ async def health():
153
+ """Health check endpoint."""
154
+ return HealthResponse(
155
+ status="healthy",
156
+ model="data-designer",
157
+ api_configured=bool(ZAI_API_KEY)
158
+ )
159
+
160
+
161
+ @app.post("/generate", response_model=GenerateResponse)
162
+ async def generate(request: GenerateRequest):
163
+ """
164
+ Generate synthetic data.
165
+ """
166
+ try:
167
+ config_builder = build_config(request)
168
+
169
+ result = data_designer.generate(
170
+ config_builder=config_builder,
171
+ num_records=request.num_records,
172
+ )
173
+
174
+ df = result.to_pandas()
175
+ data = df.to_dict(orient="records")
176
+
177
+ return GenerateResponse(
178
+ success=True,
179
+ data=data,
180
+ record_count=len(data)
181
+ )
182
+
183
+ except Exception as e:
184
+ return GenerateResponse(
185
+ success=False,
186
+ error=str(e)
187
+ )
188
+
189
+
190
+ @app.post("/preview", response_model=PreviewResponse)
191
+ async def preview(request: PreviewRequest):
192
+ """
193
+ Preview a single record without full generation.
194
+ """
195
+ try:
196
+ config_builder = build_config(request)
197
+
198
+ preview_result = data_designer.preview(config_builder=config_builder)
199
+
200
+ return PreviewResponse(
201
+ success=True,
202
+ sample=preview_result.sample_record
203
+ )
204
+
205
+ except Exception as e:
206
+ return PreviewResponse(
207
+ success=False,
208
+ error=str(e)
209
+ )
210
+
211
+
212
+ @app.get("/models")
213
+ async def list_models():
214
+ """List available z.ai models."""
215
+ return {
216
+ "models": [
217
+ {"id": "glm-5", "name": "GLM-5 (Opus)", "description": "Most capable model"},
218
+ {"id": "glm-4.7", "name": "GLM-4.7 (Sonnet)", "description": "Balanced performance"},
219
+ {"id": "glm-4.5-air", "name": "GLM-4.5-Air (Haiku)", "description": "Fast and efficient"}
220
+ ]
221
+ }
222
+
223
+
224
+ @app.get("/sampler-types")
225
+ async def list_sampler_types():
226
+ """List available sampler types."""
227
+ return {
228
+ "sampler_types": [
229
+ {"id": "CATEGORY", "params": ["values"]},
230
+ {"id": "UNIFORM", "params": ["low", "high"]},
231
+ {"id": "GAUSSIAN", "params": ["mean", "std"]},
232
+ {"id": "UUID", "params": []},
233
+ {"id": "DATETIME", "params": ["start_date", "end_date"]},
234
+ {"id": "PERSON", "params": ["locale", "include_attributes"]}
235
+ ]
236
+ }