mindchain committed on
Commit
ef0bf4f
·
verified ·
1 Parent(s): 52c5fca

v2.0: Direct z.ai Anthropic API integration

Browse files
Files changed (1) hide show
  1. app.py +97 -198
app.py CHANGED
@@ -1,9 +1,10 @@
1
  import os
 
2
  from contextlib import asynccontextmanager
3
  from fastapi import FastAPI
4
  from fastapi.middleware.cors import CORSMiddleware
5
- from typing import Any
6
  import tempfile
 
7
 
8
  from models import (
9
  GenerateRequest, GenerateResponse,
@@ -11,30 +12,13 @@ from models import (
11
  HealthResponse, ZaiModel
12
  )
13
 
14
- # z.ai OpenAI-compatible endpoint
15
  ZAI_API_KEY = os.environ.get("ZAI_API_KEY", "")
16
- ZAI_OPENAI_BASE = "https://api.z.ai/api/paas/v4/"
17
-
18
- # Set for LiteLLM
19
- os.environ["OPENAI_API_KEY"] = ZAI_API_KEY
20
- os.environ["OPENAI_API_BASE"] = ZAI_OPENAI_BASE
21
-
22
- data_designer = None
23
-
24
-
25
- @asynccontextmanager
26
- async def lifespan(app: FastAPI):
27
- global data_designer
28
- from data_designer.interface import DataDesigner
29
- data_designer = DataDesigner(artifact_path=tempfile.gettempdir())
30
- yield
31
-
32
 
33
  app = FastAPI(
34
  title="NeMo DataDesigner API",
35
- description="Synthetic data generation with NVIDIA NeMo DataDesigner and z.ai",
36
- version="1.3.0",
37
- lifespan=lifespan
38
  )
39
 
40
  app.add_middleware(
@@ -46,207 +30,122 @@ app.add_middleware(
46
  )
47
 
48
 
49
- def build_config(request):
50
- import data_designer.config as dd
51
- from data_designer.config.models import ModelConfig, ChatCompletionInferenceParams, ModelProvider
52
-
53
- config_builder = dd.DataDesignerConfigBuilder()
54
- model_id = request.model.value
55
-
56
- for col in request.columns:
57
- if col.type == "sampler":
58
- sampler_type_str = col.params.get("sampler_type", "CATEGORY")
59
- sampler_type = getattr(dd.SamplerType, sampler_type_str, dd.SamplerType.CATEGORY)
60
- params = get_sampler_params(sampler_type, col.params)
61
- config_builder.add_column(
62
- dd.SamplerColumnConfig(
63
- name=col.name,
64
- sampler_type=sampler_type,
65
- params=params,
66
- )
67
- )
68
- elif col.type == "llm_text":
69
- config_builder.add_column(
70
- dd.LLMTextColumnConfig(
71
- name=col.name,
72
- model_alias="zai-model",
73
- prompt=col.params.get("prompt", "Generate text"),
74
- )
75
- )
76
-
77
- # Custom z.ai provider with OpenAI-compatible endpoint
78
- zai_provider = ModelProvider(
79
- name="zai",
80
- endpoint=ZAI_OPENAI_BASE,
81
- api_key="ZAI_API_KEY",
82
- provider_type="openai"
83
- )
84
-
85
- model_config = ModelConfig(
86
- alias="zai-model",
87
- model=model_id, # Just the model name, no prefix
88
- provider="zai",
89
- inference_parameters=ChatCompletionInferenceParams(
90
- temperature=request.temperature,
91
- max_tokens=request.max_tokens,
92
- ),
93
- )
94
-
95
- # Pass custom provider to config builder
96
- config_builder.add_model_config(model_config)
97
-
98
- return config_builder, zai_provider
99
-
100
 
101
- def get_sampler_params(sampler_type, params):
102
- import data_designer.config as dd
103
- type_name = sampler_type.name if hasattr(sampler_type, "name") else str(sampler_type)
104
 
105
- if type_name == "CATEGORY":
106
- return dd.CategorySamplerParams(values=params.get("values", ["A", "B", "C"]))
107
- elif type_name == "UNIFORM":
108
- return dd.UniformSamplerParams(low=params.get("low", 0), high=params.get("high", 100))
109
- elif type_name == "GAUSSIAN":
110
- return dd.GaussianSamplerParams(mean=params.get("mean", 0), std=params.get("std", 1))
111
- else:
112
- return dd.CategorySamplerParams(values=["default"])
113
 
114
 
115
  @app.get("/", response_model=HealthResponse)
116
  async def root():
117
- return HealthResponse(status="healthy", model="data-designer", api_configured=bool(ZAI_API_KEY))
118
 
119
 
120
  @app.get("/health", response_model=HealthResponse)
121
  async def health():
122
- return HealthResponse(status="healthy", model="data-designer", api_configured=bool(ZAI_API_KEY))
123
 
124
 
125
  @app.post("/generate", response_model=GenerateResponse)
126
  async def generate(request: GenerateRequest):
 
127
  try:
128
- from data_designer.interface import DataDesigner
129
- import data_designer.config as dd
130
- from data_designer.config.models import ModelConfig, ChatCompletionInferenceParams, ModelProvider
131
-
132
- # Rebuild DataDesigner with custom provider
133
- zai_provider = ModelProvider(
134
- name="zai",
135
- endpoint="https://api.z.ai/api/paas/v4/",
136
- api_key="ZAI_API_KEY",
137
- provider_type="openai"
138
- )
139
 
140
- dd_custom = DataDesigner(
141
- artifact_path=tempfile.gettempdir(),
142
- model_providers=[zai_provider]
143
- )
144
 
145
- config_builder = dd.DataDesignerConfigBuilder()
146
- model_id = request.model.value
147
-
148
- for col in request.columns:
149
- if col.type == "sampler":
150
- sampler_type_str = col.params.get("sampler_type", "CATEGORY")
151
- sampler_type = getattr(dd.SamplerType, sampler_type_str, dd.SamplerType.CATEGORY)
152
- params = get_sampler_params(sampler_type, col.params)
153
- config_builder.add_column(
154
- dd.SamplerColumnConfig(
155
- name=col.name,
156
- sampler_type=sampler_type,
157
- params=params,
158
- )
159
- )
160
- elif col.type == "llm_text":
161
- config_builder.add_column(
162
- dd.LLMTextColumnConfig(
163
- name=col.name,
164
- model_alias="zai-model",
165
- prompt=col.params.get("prompt", "Generate text"),
166
- )
167
  )
168
-
169
- model_config = ModelConfig(
170
- alias="zai-model",
171
- model=model_id,
172
- provider="zai",
173
- inference_parameters=ChatCompletionInferenceParams(
174
- temperature=request.temperature,
175
- max_tokens=request.max_tokens,
176
- ),
177
- )
178
- config_builder.add_model_config(model_config)
179
-
180
- result = dd_custom.create(
181
- config_builder=config_builder,
182
- num_records=request.num_records,
183
- dataset_name="api-dataset"
184
- )
185
- df = result.load_dataset()
186
- data = df.to_dict(orient="records")
187
- return GenerateResponse(success=True, data=data, record_count=len(data))
188
  except Exception as e:
189
- import traceback
190
- return GenerateResponse(success=False, error=f"{str(e)}")
191
 
192
 
193
  @app.post("/preview", response_model=PreviewResponse)
194
  async def preview(request: PreviewRequest):
 
195
  try:
196
- from data_designer.interface import DataDesigner
197
- import data_designer.config as dd
198
- from data_designer.config.models import ModelConfig, ChatCompletionInferenceParams, ModelProvider
199
 
200
- zai_provider = ModelProvider(
201
- name="zai",
202
- endpoint="https://api.z.ai/api/paas/v4/",
203
- api_key="ZAI_API_KEY",
204
- provider_type="openai"
205
- )
206
 
207
- dd_custom = DataDesigner(
208
- artifact_path=tempfile.gettempdir(),
209
- model_providers=[zai_provider]
210
- )
211
 
212
- config_builder = dd.DataDesignerConfigBuilder()
213
- model_id = request.model.value
214
-
215
- for col in request.columns:
216
- if col.type == "sampler":
217
- sampler_type_str = col.params.get("sampler_type", "CATEGORY")
218
- sampler_type = getattr(dd.SamplerType, sampler_type_str, dd.SamplerType.CATEGORY)
219
- params = get_sampler_params(sampler_type, col.params)
220
- config_builder.add_column(
221
- dd.SamplerColumnConfig(
222
- name=col.name,
223
- sampler_type=sampler_type,
224
- params=params,
225
- )
226
- )
227
- elif col.type == "llm_text":
228
- config_builder.add_column(
229
- dd.LLMTextColumnConfig(
230
- name=col.name,
231
- model_alias="zai-model",
232
- prompt=col.params.get("prompt", "Generate text"),
233
- )
234
- )
235
-
236
- model_config = ModelConfig(
237
- alias="zai-model",
238
- model=model_id,
239
- provider="zai",
240
- inference_parameters=ChatCompletionInferenceParams(
241
- temperature=request.temperature,
242
- max_tokens=request.max_tokens,
243
- ),
244
- )
245
- config_builder.add_model_config(model_config)
246
-
247
- preview_result = dd_custom.preview(config_builder=config_builder, num_records=1)
248
- sample = preview_result.dataset.to_dict(orient="records")[0] if len(preview_result.dataset) > 0 else {}
249
- return PreviewResponse(success=True, sample=sample)
250
  except Exception as e:
251
  return PreviewResponse(success=False, error=str(e))
252
 
@@ -254,9 +153,9 @@ async def preview(request: PreviewRequest):
254
  @app.get("/models")
255
  async def list_models():
256
  return {"models": [
257
- {"id": "glm-5", "name": "GLM-5", "description": "Most capable"},
258
- {"id": "glm-4.7", "name": "GLM-4.7", "description": "Balanced"},
259
- {"id": "glm-4.5-air", "name": "GLM-4.5-Air", "description": "Fast"}
260
  ]}
261
 
262
 
 
1
  import os
2
+ import httpx
3
  from contextlib import asynccontextmanager
4
  from fastapi import FastAPI
5
  from fastapi.middleware.cors import CORSMiddleware
 
6
  import tempfile
7
+ import random
8
 
9
  from models import (
10
  GenerateRequest, GenerateResponse,
 
12
  HealthResponse, ZaiModel
13
  )
14
 
 
15
  ZAI_API_KEY = os.environ.get("ZAI_API_KEY", "")
16
+ ZAI_BASE_URL = "https://api.z.ai/api/anthropic"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
 
18
  app = FastAPI(
19
  title="NeMo DataDesigner API",
20
+ description="Synthetic data generation with z.ai",
21
+ version="2.0.0"
 
22
  )
23
 
24
  app.add_middleware(
 
30
  )
31
 
32
 
33
async def call_zai(prompt: str, model: str, temperature: float, max_tokens: int) -> str:
    """Call the z.ai Anthropic-compatible messages endpoint and return the reply text.

    Args:
        prompt: User message content.
        model: z.ai model id (e.g. "glm-4.7").
        temperature: Sampling temperature forwarded to the API.
        max_tokens: Upper bound on generated tokens.

    Returns:
        The text of the first content block in the response.

    Raises:
        Exception: On any non-200 response, with status code and body included.
    """
    async with httpx.AsyncClient(timeout=60.0) as client:
        response = await client.post(
            f"{ZAI_BASE_URL}/v1/messages",
            headers={
                "x-api-key": ZAI_API_KEY,
                "anthropic-version": "2023-06-01",
                "content-type": "application/json"
            },
            json={
                "model": model,
                "max_tokens": max_tokens,
                # Fix: temperature was accepted but never forwarded to the API,
                # so request.temperature had no effect.
                "temperature": temperature,
                "messages": [{"role": "user", "content": prompt}]
            }
        )
        if response.status_code != 200:
            raise Exception(f"z.ai API error: {response.status_code} - {response.text}")
        data = response.json()
        return data["content"][0]["text"]
53
+
54
+
55
def sample_value(sampler_type: str, params: dict) -> str:
    """Return one sampled value (as a string) for the given sampler type.

    CATEGORY picks uniformly from params["values"]; UNIFORM draws an
    integer in [low, high]; GAUSSIAN draws from N(mean, std) rounded to
    two decimals. Any other type yields the literal "default".
    """
    if sampler_type == "CATEGORY":
        choices = params.get("values", ["A", "B", "C"])
        return random.choice(choices)
    if sampler_type == "UNIFORM":
        return str(random.randint(params.get("low", 0), params.get("high", 100)))
    if sampler_type == "GAUSSIAN":
        drawn = random.gauss(params.get("mean", 0), params.get("std", 1))
        return str(round(drawn, 2))
    return "default"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
70
 
 
 
 
71
 
72
def render_prompt(template: str, context: dict) -> str:
    """Render a prompt template, substituting {{ key }} placeholders.

    Generalization: accepts any amount of whitespace inside the braces
    ("{{key}}", "{{ key }}", "{{  key  }}"); the previous exact-string
    replacement only matched zero or exactly one space. Values are
    stringified with str().

    Args:
        template: Prompt text containing {{ key }} placeholders.
        context: Mapping of placeholder names to values.

    Returns:
        The template with every known placeholder replaced.
    """
    import re  # local import keeps the block self-contained

    result = template
    for key, value in context.items():
        # re.escape guards keys that contain regex metacharacters; the
        # lambda replacement avoids backslash-escape interpretation in values.
        pattern = r"\{\{\s*" + re.escape(str(key)) + r"\s*\}\}"
        result = re.sub(pattern, lambda _m, v=str(value): v, result)
    return result
 
79
 
80
 
81
@app.get("/", response_model=HealthResponse)
async def root():
    """Liveness probe; reports whether the z.ai API key is configured."""
    configured = bool(ZAI_API_KEY)
    return HealthResponse(status="healthy", model="z.ai", api_configured=configured)
84
 
85
 
86
@app.get("/health", response_model=HealthResponse)
async def health():
    """Health check endpoint; mirrors the root liveness probe."""
    has_key = bool(ZAI_API_KEY)
    return HealthResponse(status="healthy", model="z.ai", api_configured=has_key)
89
 
90
 
91
@app.post("/generate", response_model=GenerateResponse)
async def generate(request: GenerateRequest):
    """Generate synthetic data using z.ai API."""
    try:
        model = request.model.value

        # Split columns by kind: samplers are local, llm_text needs API calls.
        sampler_cols = {c.name: c for c in request.columns if c.type == "sampler"}
        llm_cols = [c for c in request.columns if c.type == "llm_text"]

        records = []
        for _ in range(request.num_records):
            # Sampler values come first so LLM prompts can reference them.
            record = {
                name: sample_value(col.params.get("sampler_type", "CATEGORY"), col.params)
                for name, col in sampler_cols.items()
            }

            # Fill LLM columns via z.ai, one call per column.
            for col in llm_cols:
                rendered = render_prompt(col.params.get("prompt", "Generate text"), record)
                record[col.name] = await call_zai(rendered, model, request.temperature, request.max_tokens)

            records.append(record)

        return GenerateResponse(success=True, data=records, record_count=len(records))
    except Exception as e:
        return GenerateResponse(success=False, error=str(e))
 
123
 
124
 
125
@app.post("/preview", response_model=PreviewResponse)
async def preview(request: PreviewRequest):
    """Preview a single record."""
    try:
        model = request.model.value

        # Same split as /generate: local samplers vs. API-backed llm_text.
        sampler_cols = {c.name: c for c in request.columns if c.type == "sampler"}
        llm_cols = [c for c in request.columns if c.type == "llm_text"]

        # One record only: samplers first so LLM prompts can reference them.
        record = {
            name: sample_value(col.params.get("sampler_type", "CATEGORY"), col.params)
            for name, col in sampler_cols.items()
        }
        for col in llm_cols:
            rendered = render_prompt(col.params.get("prompt", "Generate text"), record)
            record[col.name] = await call_zai(rendered, model, request.temperature, request.max_tokens)

        return PreviewResponse(success=True, sample=record)
    except Exception as e:
        return PreviewResponse(success=False, error=str(e))
151
 
 
153
@app.get("/models")
async def list_models():
    """List the supported z.ai GLM models with their Anthropic-tier aliases."""
    catalog = [
        ("glm-5", "GLM-5 (Opus)", "Most capable"),
        ("glm-4.7", "GLM-4.7 (Sonnet)", "Balanced"),
        ("glm-4.5-air", "GLM-4.5-Air (Haiku)", "Fast"),
    ]
    return {"models": [
        {"id": model_id, "name": name, "description": desc}
        for model_id, name, desc in catalog
    ]}
160
 
161