mindchain committed on
Commit
52c5fca
·
verified ·
1 Parent(s): f8ba930

Use z.ai OpenAI-compatible endpoint (v1.3)

Browse files
Files changed (1) hide show
  1. app.py +135 -29
app.py CHANGED
@@ -2,6 +2,7 @@ import os
2
  from contextlib import asynccontextmanager
3
  from fastapi import FastAPI
4
  from fastapi.middleware.cors import CORSMiddleware
 
5
  import tempfile
6
 
7
  from models import (
@@ -10,7 +11,14 @@ from models import (
10
  HealthResponse, ZaiModel
11
  )
12
 
 
13
  ZAI_API_KEY = os.environ.get("ZAI_API_KEY", "")
 
 
 
 
 
 
14
  data_designer = None
15
 
16
 
@@ -18,27 +26,14 @@ data_designer = None
18
  async def lifespan(app: FastAPI):
19
  global data_designer
20
  from data_designer.interface import DataDesigner
21
- from data_designer.config.models import ModelProvider
22
-
23
- # Create custom z.ai provider
24
- zai_provider = ModelProvider(
25
- name="zai",
26
- endpoint="https://api.z.ai/api/anthropic",
27
- provider_type="openai",
28
- api_key="ZAI_API_KEY",
29
- )
30
-
31
- data_designer = DataDesigner(
32
- artifact_path=tempfile.gettempdir(),
33
- model_providers=[zai_provider]
34
- )
35
  yield
36
 
37
 
38
  app = FastAPI(
39
  title="NeMo DataDesigner API",
40
- description="Synthetic data generation with z.ai",
41
- version="2.0.0",
42
  lifespan=lifespan
43
  )
44
 
@@ -53,7 +48,7 @@ app.add_middleware(
53
 
54
  def build_config(request):
55
  import data_designer.config as dd
56
- from data_designer.config.models import ModelConfig, ChatCompletionInferenceParams
57
 
58
  config_builder = dd.DataDesignerConfigBuilder()
59
  model_id = request.model.value
@@ -79,19 +74,28 @@ def build_config(request):
79
  )
80
  )
81
 
82
- # Use zai provider with openai format
 
 
 
 
 
 
 
83
  model_config = ModelConfig(
84
  alias="zai-model",
85
- model=f"openai/{model_id}",
86
  provider="zai",
87
  inference_parameters=ChatCompletionInferenceParams(
88
  temperature=request.temperature,
89
  max_tokens=request.max_tokens,
90
  ),
91
  )
 
 
92
  config_builder.add_model_config(model_config)
93
-
94
- return config_builder
95
 
96
 
97
  def get_sampler_params(sampler_type, params):
@@ -121,8 +125,59 @@ async def health():
121
  @app.post("/generate", response_model=GenerateResponse)
122
  async def generate(request: GenerateRequest):
123
  try:
124
- config_builder = build_config(request)
125
- result = data_designer.create(
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
126
  config_builder=config_builder,
127
  num_records=request.num_records,
128
  dataset_name="api-dataset"
@@ -131,14 +186,65 @@ async def generate(request: GenerateRequest):
131
  data = df.to_dict(orient="records")
132
  return GenerateResponse(success=True, data=data, record_count=len(data))
133
  except Exception as e:
134
- return GenerateResponse(success=False, error=str(e))
 
135
 
136
 
137
  @app.post("/preview", response_model=PreviewResponse)
138
  async def preview(request: PreviewRequest):
139
  try:
140
- config_builder = build_config(request)
141
- preview_result = data_designer.preview(config_builder=config_builder, num_records=1)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
142
  sample = preview_result.dataset.to_dict(orient="records")[0] if len(preview_result.dataset) > 0 else {}
143
  return PreviewResponse(success=True, sample=sample)
144
  except Exception as e:
@@ -148,9 +254,9 @@ async def preview(request: PreviewRequest):
148
  @app.get("/models")
149
  async def list_models():
150
  return {"models": [
151
- {"id": "glm-5", "name": "GLM-5 (Opus)", "description": "Most capable model"},
152
- {"id": "glm-4.7", "name": "GLM-4.7 (Sonnet)", "description": "Balanced"},
153
- {"id": "glm-4.5-air", "name": "GLM-4.5-Air (Haiku)", "description": "Fast"}
154
  ]}
155
 
156
 
 
2
  from contextlib import asynccontextmanager
3
  from fastapi import FastAPI
4
  from fastapi.middleware.cors import CORSMiddleware
5
+ from typing import Any
6
  import tempfile
7
 
8
  from models import (
 
11
  HealthResponse, ZaiModel
12
  )
13
 
14
# z.ai exposes an OpenAI-compatible endpoint; point LiteLLM's OpenAI client
# at it by populating the standard OPENAI_* environment variables.
ZAI_API_KEY = os.environ.get("ZAI_API_KEY", "")
ZAI_OPENAI_BASE = "https://api.z.ai/api/paas/v4/"

# NOTE(review): these are written unconditionally, even when ZAI_API_KEY is
# empty — confirm no pre-existing OpenAI credentials must survive startup.
os.environ["OPENAI_API_KEY"] = ZAI_API_KEY
os.environ["OPENAI_API_BASE"] = ZAI_OPENAI_BASE

# Shared DataDesigner instance; assigned by the FastAPI lifespan handler.
data_designer = None
23
 
24
 
 
26
async def lifespan(app: FastAPI):
    """FastAPI lifespan hook: build the shared DataDesigner once at startup.

    Artifacts are written under the system temp directory. The instance is
    published through the module-level ``data_designer`` global.
    """
    global data_designer
    # Imported lazily so the heavy dependency is only loaded when the app
    # actually starts, not at module import time.
    from data_designer.interface import DataDesigner

    data_designer = DataDesigner(artifact_path=tempfile.gettempdir())
    yield
31
 
32
 
33
# Application object; wired to the lifespan handler defined above so the
# DataDesigner instance is provisioned on startup.
app = FastAPI(
    title="NeMo DataDesigner API",
    version="1.3.0",
    description="Synthetic data generation with NVIDIA NeMo DataDesigner and z.ai",
    lifespan=lifespan,
)
39
 
 
48
 
49
  def build_config(request):
50
  import data_designer.config as dd
51
+ from data_designer.config.models import ModelConfig, ChatCompletionInferenceParams, ModelProvider
52
 
53
  config_builder = dd.DataDesignerConfigBuilder()
54
  model_id = request.model.value
 
74
  )
75
  )
76
 
77
+ # Custom z.ai provider with OpenAI-compatible endpoint
78
+ zai_provider = ModelProvider(
79
+ name="zai",
80
+ endpoint=ZAI_OPENAI_BASE,
81
+ api_key="ZAI_API_KEY",
82
+ provider_type="openai"
83
+ )
84
+
85
  model_config = ModelConfig(
86
  alias="zai-model",
87
+ model=model_id, # Just the model name, no prefix
88
  provider="zai",
89
  inference_parameters=ChatCompletionInferenceParams(
90
  temperature=request.temperature,
91
  max_tokens=request.max_tokens,
92
  ),
93
  )
94
+
95
+ # Pass custom provider to config builder
96
  config_builder.add_model_config(model_config)
97
+
98
+ return config_builder, zai_provider
99
 
100
 
101
  def get_sampler_params(sampler_type, params):
 
125
  @app.post("/generate", response_model=GenerateResponse)
126
  async def generate(request: GenerateRequest):
127
  try:
128
+ from data_designer.interface import DataDesigner
129
+ import data_designer.config as dd
130
+ from data_designer.config.models import ModelConfig, ChatCompletionInferenceParams, ModelProvider
131
+
132
+ # Rebuild DataDesigner with custom provider
133
+ zai_provider = ModelProvider(
134
+ name="zai",
135
+ endpoint="https://api.z.ai/api/paas/v4/",
136
+ api_key="ZAI_API_KEY",
137
+ provider_type="openai"
138
+ )
139
+
140
+ dd_custom = DataDesigner(
141
+ artifact_path=tempfile.gettempdir(),
142
+ model_providers=[zai_provider]
143
+ )
144
+
145
+ config_builder = dd.DataDesignerConfigBuilder()
146
+ model_id = request.model.value
147
+
148
+ for col in request.columns:
149
+ if col.type == "sampler":
150
+ sampler_type_str = col.params.get("sampler_type", "CATEGORY")
151
+ sampler_type = getattr(dd.SamplerType, sampler_type_str, dd.SamplerType.CATEGORY)
152
+ params = get_sampler_params(sampler_type, col.params)
153
+ config_builder.add_column(
154
+ dd.SamplerColumnConfig(
155
+ name=col.name,
156
+ sampler_type=sampler_type,
157
+ params=params,
158
+ )
159
+ )
160
+ elif col.type == "llm_text":
161
+ config_builder.add_column(
162
+ dd.LLMTextColumnConfig(
163
+ name=col.name,
164
+ model_alias="zai-model",
165
+ prompt=col.params.get("prompt", "Generate text"),
166
+ )
167
+ )
168
+
169
+ model_config = ModelConfig(
170
+ alias="zai-model",
171
+ model=model_id,
172
+ provider="zai",
173
+ inference_parameters=ChatCompletionInferenceParams(
174
+ temperature=request.temperature,
175
+ max_tokens=request.max_tokens,
176
+ ),
177
+ )
178
+ config_builder.add_model_config(model_config)
179
+
180
+ result = dd_custom.create(
181
  config_builder=config_builder,
182
  num_records=request.num_records,
183
  dataset_name="api-dataset"
 
186
  data = df.to_dict(orient="records")
187
  return GenerateResponse(success=True, data=data, record_count=len(data))
188
  except Exception as e:
189
+ import traceback
190
+ return GenerateResponse(success=False, error=f"{str(e)}")
191
 
192
 
193
  @app.post("/preview", response_model=PreviewResponse)
194
  async def preview(request: PreviewRequest):
195
  try:
196
+ from data_designer.interface import DataDesigner
197
+ import data_designer.config as dd
198
+ from data_designer.config.models import ModelConfig, ChatCompletionInferenceParams, ModelProvider
199
+
200
+ zai_provider = ModelProvider(
201
+ name="zai",
202
+ endpoint="https://api.z.ai/api/paas/v4/",
203
+ api_key="ZAI_API_KEY",
204
+ provider_type="openai"
205
+ )
206
+
207
+ dd_custom = DataDesigner(
208
+ artifact_path=tempfile.gettempdir(),
209
+ model_providers=[zai_provider]
210
+ )
211
+
212
+ config_builder = dd.DataDesignerConfigBuilder()
213
+ model_id = request.model.value
214
+
215
+ for col in request.columns:
216
+ if col.type == "sampler":
217
+ sampler_type_str = col.params.get("sampler_type", "CATEGORY")
218
+ sampler_type = getattr(dd.SamplerType, sampler_type_str, dd.SamplerType.CATEGORY)
219
+ params = get_sampler_params(sampler_type, col.params)
220
+ config_builder.add_column(
221
+ dd.SamplerColumnConfig(
222
+ name=col.name,
223
+ sampler_type=sampler_type,
224
+ params=params,
225
+ )
226
+ )
227
+ elif col.type == "llm_text":
228
+ config_builder.add_column(
229
+ dd.LLMTextColumnConfig(
230
+ name=col.name,
231
+ model_alias="zai-model",
232
+ prompt=col.params.get("prompt", "Generate text"),
233
+ )
234
+ )
235
+
236
+ model_config = ModelConfig(
237
+ alias="zai-model",
238
+ model=model_id,
239
+ provider="zai",
240
+ inference_parameters=ChatCompletionInferenceParams(
241
+ temperature=request.temperature,
242
+ max_tokens=request.max_tokens,
243
+ ),
244
+ )
245
+ config_builder.add_model_config(model_config)
246
+
247
+ preview_result = dd_custom.preview(config_builder=config_builder, num_records=1)
248
  sample = preview_result.dataset.to_dict(orient="records")[0] if len(preview_result.dataset) > 0 else {}
249
  return PreviewResponse(success=True, sample=sample)
250
  except Exception as e:
 
254
@app.get("/models")
async def list_models():
    """Return the catalog of z.ai GLM models this API accepts."""
    catalog = [
        {"id": "glm-5", "name": "GLM-5", "description": "Most capable"},
        {"id": "glm-4.7", "name": "GLM-4.7", "description": "Balanced"},
        {"id": "glm-4.5-air", "name": "GLM-4.5-Air", "description": "Fast"},
    ]
    return {"models": catalog}
261
 
262