mindchain commited on
Commit
389fe48
·
verified ·
1 Parent(s): c6540a4

Upload app.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +9 -13
app.py CHANGED
@@ -12,12 +12,13 @@ from models import (
12
  HealthResponse, ZaiModel, SamplerType
13
  )
14
 
15
- # Configure z.ai as OpenAI-compatible provider
16
  ZAI_API_KEY = os.environ.get("ZAI_API_KEY", "")
17
  ZAI_BASE_URL = "https://api.z.ai/api/anthropic"
18
 
19
- os.environ["OPENAI_API_KEY"] = ZAI_API_KEY
20
- os.environ["OPENAI_API_BASE"] = ZAI_BASE_URL
 
21
 
22
  # Global DataDesigner instance
23
  data_designer = None
@@ -27,7 +28,6 @@ data_designer = None
27
  async def lifespan(app: FastAPI):
28
  global data_designer
29
  from data_designer.interface import DataDesigner
30
- # Use temp directory for artifacts
31
  data_designer = DataDesigner(artifact_path=tempfile.gettempdir())
32
  yield
33
 
@@ -97,11 +97,11 @@ def build_config(request: GenerateRequest | PreviewRequest):
97
  )
98
  )
99
 
100
- # Add model config
101
  model_config = ModelConfig(
102
  alias="zai-model",
103
- model=f"openai/{model_id}",
104
- provider="openai",
105
  inference_parameters=ChatCompletionInferenceParams(
106
  temperature=request.temperature,
107
  max_tokens=request.max_tokens,
@@ -169,14 +169,12 @@ async def generate(request: GenerateRequest):
169
  try:
170
  config_builder = build_config(request)
171
 
172
- # Use create() method - returns DatasetCreationResults
173
  result = data_designer.create(
174
  config_builder=config_builder,
175
  num_records=request.num_records,
176
  dataset_name="api-dataset"
177
  )
178
 
179
- # Get DataFrame from results
180
  df = result.load_dataset()
181
  data = df.to_dict(orient="records")
182
 
@@ -190,7 +188,7 @@ async def generate(request: GenerateRequest):
190
  import traceback
191
  return GenerateResponse(
192
  success=False,
193
- error=f"{str(e)}\n{traceback.format_exc()}"
194
  )
195
 
196
 
@@ -202,13 +200,11 @@ async def preview(request: PreviewRequest):
202
  try:
203
  config_builder = build_config(request)
204
 
205
- # Use preview() method - returns PreviewResults
206
  preview_result = data_designer.preview(
207
  config_builder=config_builder,
208
  num_records=1
209
  )
210
 
211
- # Get sample record
212
  sample = preview_result.dataset.to_dict(orient="records")[0] if len(preview_result.dataset) > 0 else {}
213
 
214
  return PreviewResponse(
@@ -220,7 +216,7 @@ async def preview(request: PreviewRequest):
220
  import traceback
221
  return PreviewResponse(
222
  success=False,
223
- error=f"{str(e)}\n{traceback.format_exc()}"
224
  )
225
 
226
 
 
12
  HealthResponse, ZaiModel, SamplerType
13
  )
14
 
15
+ # Configure z.ai as Anthropic-compatible provider for LiteLLM
16
  ZAI_API_KEY = os.environ.get("ZAI_API_KEY", "")
17
  ZAI_BASE_URL = "https://api.z.ai/api/anthropic"
18
 
19
+ # LiteLLM uses these env vars for custom Anthropic endpoint
20
+ os.environ["ANTHROPIC_API_KEY"] = ZAI_API_KEY
21
+ os.environ["ANTHROPIC_BASE_URL"] = ZAI_BASE_URL
22
 
23
  # Global DataDesigner instance
24
  data_designer = None
 
28
  async def lifespan(app: FastAPI):
29
  global data_designer
30
  from data_designer.interface import DataDesigner
 
31
  data_designer = DataDesigner(artifact_path=tempfile.gettempdir())
32
  yield
33
 
 
97
  )
98
  )
99
 
100
+ # Add model config - use anthropic provider with custom base_url via env
101
  model_config = ModelConfig(
102
  alias="zai-model",
103
+ model=f"anthropic/{model_id}", # LiteLLM Anthropic format
104
+ provider="anthropic",
105
  inference_parameters=ChatCompletionInferenceParams(
106
  temperature=request.temperature,
107
  max_tokens=request.max_tokens,
 
169
  try:
170
  config_builder = build_config(request)
171
 
 
172
  result = data_designer.create(
173
  config_builder=config_builder,
174
  num_records=request.num_records,
175
  dataset_name="api-dataset"
176
  )
177
 
 
178
  df = result.load_dataset()
179
  data = df.to_dict(orient="records")
180
 
 
188
  import traceback
189
  return GenerateResponse(
190
  success=False,
191
+ error=f"{str(e)}"
192
  )
193
 
194
 
 
200
  try:
201
  config_builder = build_config(request)
202
 
 
203
  preview_result = data_designer.preview(
204
  config_builder=config_builder,
205
  num_records=1
206
  )
207
 
 
208
  sample = preview_result.dataset.to_dict(orient="records")[0] if len(preview_result.dataset) > 0 else {}
209
 
210
  return PreviewResponse(
 
216
  import traceback
217
  return PreviewResponse(
218
  success=False,
219
+ error=f"{str(e)}"
220
  )
221
 
222