mindchain commited on
Commit
eae7491
·
verified ·
1 Parent(s): 389fe48

Fix: Use Anthropic provider for z.ai API

Browse files
Files changed (1) hide show
  1. app.py +20 -83
app.py CHANGED
@@ -4,23 +4,21 @@ from fastapi import FastAPI, HTTPException
4
  from fastapi.middleware.cors import CORSMiddleware
5
  from typing import Any
6
  import tempfile
7
- from pathlib import Path
8
 
9
  from models import (
10
  GenerateRequest, GenerateResponse,
11
  PreviewRequest, PreviewResponse,
12
- HealthResponse, ZaiModel, SamplerType
13
  )
14
 
15
- # Configure z.ai as Anthropic-compatible provider for LiteLLM
16
  ZAI_API_KEY = os.environ.get("ZAI_API_KEY", "")
17
- ZAI_BASE_URL = "https://api.z.ai/api/anthropic"
18
 
19
- # LiteLLM uses these env vars for custom Anthropic endpoint
20
- os.environ["ANTHROPIC_API_KEY"] = ZAI_API_KEY
21
- os.environ["ANTHROPIC_BASE_URL"] = ZAI_BASE_URL
22
 
23
- # Global DataDesigner instance
24
  data_designer = None
25
 
26
 
@@ -54,10 +52,8 @@ def build_config(request: GenerateRequest | PreviewRequest):
54
  from data_designer.config.models import ModelConfig, ChatCompletionInferenceParams
55
 
56
  config_builder = dd.DataDesignerConfigBuilder()
57
-
58
  model_id = request.model.value
59
 
60
- # Process columns
61
  for col in request.columns:
62
  if col.type == "sampler":
63
  sampler_type_str = col.params.get("sampler_type", "CATEGORY")
@@ -97,11 +93,10 @@ def build_config(request: GenerateRequest | PreviewRequest):
97
  )
98
  )
99
 
100
- # Add model config - use anthropic provider with custom base_url via env
101
  model_config = ModelConfig(
102
  alias="zai-model",
103
- model=f"anthropic/{model_id}", # LiteLLM Anthropic format
104
- provider="anthropic",
105
  inference_parameters=ChatCompletionInferenceParams(
106
  temperature=request.temperature,
107
  max_tokens=request.max_tokens,
@@ -113,116 +108,60 @@ def build_config(request: GenerateRequest | PreviewRequest):
113
 
114
 
115
  def get_sampler_params(sampler_type, params: dict) -> Any:
116
- """Get appropriate sampler params based on type."""
117
  import data_designer.config as dd
118
-
119
  type_name = sampler_type.name if hasattr(sampler_type, 'name') else str(sampler_type)
120
 
121
  if type_name == "CATEGORY":
122
- return dd.CategorySamplerParams(
123
- values=params.get("values", ["A", "B", "C"])
124
- )
125
  elif type_name == "UNIFORM":
126
- return dd.UniformSamplerParams(
127
- low=params.get("low", 0),
128
- high=params.get("high", 100)
129
- )
130
  elif type_name == "GAUSSIAN":
131
- return dd.GaussianSamplerParams(
132
- mean=params.get("mean", 0),
133
- std=params.get("std", 1)
134
- )
135
  elif type_name == "DATETIME":
136
- return dd.DateTimeSamplerParams(
137
- start_date=params.get("start_date", "2020-01-01"),
138
- end_date=params.get("end_date", "2025-12-31")
139
- )
140
  else:
141
  return dd.CategorySamplerParams(values=["default"])
142
 
143
 
144
  @app.get("/", response_model=HealthResponse)
145
  async def root():
146
- """Health check endpoint."""
147
- return HealthResponse(
148
- status="healthy",
149
- model="data-designer",
150
- api_configured=bool(ZAI_API_KEY)
151
- )
152
 
153
 
154
  @app.get("/health", response_model=HealthResponse)
155
  async def health():
156
- """Health check endpoint."""
157
- return HealthResponse(
158
- status="healthy",
159
- model="data-designer",
160
- api_configured=bool(ZAI_API_KEY)
161
- )
162
 
163
 
164
  @app.post("/generate", response_model=GenerateResponse)
165
  async def generate(request: GenerateRequest):
166
- """
167
- Generate synthetic data using DataDesigner.create().
168
- """
169
  try:
170
  config_builder = build_config(request)
171
-
172
  result = data_designer.create(
173
  config_builder=config_builder,
174
  num_records=request.num_records,
175
  dataset_name="api-dataset"
176
  )
177
-
178
  df = result.load_dataset()
179
  data = df.to_dict(orient="records")
180
-
181
- return GenerateResponse(
182
- success=True,
183
- data=data,
184
- record_count=len(data)
185
- )
186
-
187
  except Exception as e:
188
- import traceback
189
- return GenerateResponse(
190
- success=False,
191
- error=f"{str(e)}"
192
- )
193
 
194
 
195
  @app.post("/preview", response_model=PreviewResponse)
196
  async def preview(request: PreviewRequest):
197
- """
198
- Preview a single record using DataDesigner.preview().
199
- """
200
  try:
201
  config_builder = build_config(request)
202
-
203
- preview_result = data_designer.preview(
204
- config_builder=config_builder,
205
- num_records=1
206
- )
207
-
208
  sample = preview_result.dataset.to_dict(orient="records")[0] if len(preview_result.dataset) > 0 else {}
209
-
210
- return PreviewResponse(
211
- success=True,
212
- sample=sample
213
- )
214
-
215
  except Exception as e:
216
- import traceback
217
- return PreviewResponse(
218
- success=False,
219
- error=f"{str(e)}"
220
- )
221
 
222
 
223
  @app.get("/models")
224
  async def list_models():
225
- """List available z.ai models."""
226
  return {
227
  "models": [
228
  {"id": "glm-5", "name": "GLM-5 (Opus)", "description": "Most capable model"},
@@ -234,14 +173,12 @@ async def list_models():
234
 
235
  @app.get("/sampler-types")
236
  async def list_sampler_types():
237
- """List available sampler types."""
238
  return {
239
  "sampler_types": [
240
  {"id": "CATEGORY", "params": ["values"]},
241
  {"id": "UNIFORM", "params": ["low", "high"]},
242
  {"id": "GAUSSIAN", "params": ["mean", "std"]},
243
  {"id": "UUID", "params": []},
244
- {"id": "DATETIME", "params": ["start_date", "end_date"]},
245
- {"id": "PERSON", "params": ["locale", "include_attributes"]}
246
  ]
247
  }
 
4
  from fastapi.middleware.cors import CORSMiddleware
5
  from typing import Any
6
  import tempfile
 
7
 
8
  from models import (
9
  GenerateRequest, GenerateResponse,
10
  PreviewRequest, PreviewResponse,
11
+ HealthResponse, ZaiModel
12
  )
13
 
14
+ # Configure z.ai as OpenAI-compatible provider
15
  ZAI_API_KEY = os.environ.get("ZAI_API_KEY", "")
16
+ ZAI_BASE_URL = "https://api.z.ai/api/openai"
17
 
18
+ # LiteLLM OpenAI-compatible config
19
+ os.environ["OPENAI_API_KEY"] = ZAI_API_KEY
20
+ os.environ["OPENAI_API_BASE"] = ZAI_BASE_URL
21
 
 
22
  data_designer = None
23
 
24
 
 
52
  from data_designer.config.models import ModelConfig, ChatCompletionInferenceParams
53
 
54
  config_builder = dd.DataDesignerConfigBuilder()
 
55
  model_id = request.model.value
56
 
 
57
  for col in request.columns:
58
  if col.type == "sampler":
59
  sampler_type_str = col.params.get("sampler_type", "CATEGORY")
 
93
  )
94
  )
95
 
 
96
  model_config = ModelConfig(
97
  alias="zai-model",
98
+ model=f"openai/{model_id}",
99
+ provider="openai",
100
  inference_parameters=ChatCompletionInferenceParams(
101
  temperature=request.temperature,
102
  max_tokens=request.max_tokens,
 
108
 
109
 
110
  def get_sampler_params(sampler_type, params: dict) -> Any:
 
111
  import data_designer.config as dd
 
112
  type_name = sampler_type.name if hasattr(sampler_type, 'name') else str(sampler_type)
113
 
114
  if type_name == "CATEGORY":
115
+ return dd.CategorySamplerParams(values=params.get("values", ["A", "B", "C"]))
 
 
116
  elif type_name == "UNIFORM":
117
+ return dd.UniformSamplerParams(low=params.get("low", 0), high=params.get("high", 100))
 
 
 
118
  elif type_name == "GAUSSIAN":
119
+ return dd.GaussianSamplerParams(mean=params.get("mean", 0), std=params.get("std", 1))
 
 
 
120
  elif type_name == "DATETIME":
121
+ return dd.DateTimeSamplerParams(start_date=params.get("start_date", "2020-01-01"), end_date=params.get("end_date", "2025-12-31"))
 
 
 
122
  else:
123
  return dd.CategorySamplerParams(values=["default"])
124
 
125
 
126
  @app.get("/", response_model=HealthResponse)
127
  async def root():
128
+ return HealthResponse(status="healthy", model="data-designer", api_configured=bool(ZAI_API_KEY))
 
 
 
 
 
129
 
130
 
131
  @app.get("/health", response_model=HealthResponse)
132
  async def health():
133
+ return HealthResponse(status="healthy", model="data-designer", api_configured=bool(ZAI_API_KEY))
 
 
 
 
 
134
 
135
 
136
  @app.post("/generate", response_model=GenerateResponse)
137
  async def generate(request: GenerateRequest):
 
 
 
138
  try:
139
  config_builder = build_config(request)
 
140
  result = data_designer.create(
141
  config_builder=config_builder,
142
  num_records=request.num_records,
143
  dataset_name="api-dataset"
144
  )
 
145
  df = result.load_dataset()
146
  data = df.to_dict(orient="records")
147
+ return GenerateResponse(success=True, data=data, record_count=len(data))
 
 
 
 
 
 
148
  except Exception as e:
149
+ return GenerateResponse(success=False, error=str(e))
 
 
 
 
150
 
151
 
152
  @app.post("/preview", response_model=PreviewResponse)
153
  async def preview(request: PreviewRequest):
 
 
 
154
  try:
155
  config_builder = build_config(request)
156
+ preview_result = data_designer.preview(config_builder=config_builder, num_records=1)
 
 
 
 
 
157
  sample = preview_result.dataset.to_dict(orient="records")[0] if len(preview_result.dataset) > 0 else {}
158
+ return PreviewResponse(success=True, sample=sample)
 
 
 
 
 
159
  except Exception as e:
160
+ return PreviewResponse(success=False, error=str(e))
 
 
 
 
161
 
162
 
163
  @app.get("/models")
164
  async def list_models():
 
165
  return {
166
  "models": [
167
  {"id": "glm-5", "name": "GLM-5 (Opus)", "description": "Most capable model"},
 
173
 
174
  @app.get("/sampler-types")
175
  async def list_sampler_types():
 
176
  return {
177
  "sampler_types": [
178
  {"id": "CATEGORY", "params": ["values"]},
179
  {"id": "UNIFORM", "params": ["low", "high"]},
180
  {"id": "GAUSSIAN", "params": ["mean", "std"]},
181
  {"id": "UUID", "params": []},
182
+ {"id": "DATETIME", "params": ["start_date", "end_date"]}
 
183
  ]
184
  }