mindchain committed on
Commit
f8ba930
·
verified ·
1 Parent(s): 0ef55d4

v2.0: Custom z.ai ModelProvider

Browse files
Files changed (1) hide show
  1. app.py +19 -14
app.py CHANGED
@@ -2,7 +2,6 @@ import os
2
  from contextlib import asynccontextmanager
3
  from fastapi import FastAPI
4
  from fastapi.middleware.cors import CORSMiddleware
5
- from typing import Any
6
  import tempfile
7
 
8
  from models import (
@@ -11,14 +10,7 @@ from models import (
11
  HealthResponse, ZaiModel
12
  )
13
 
14
- # Configure z.ai - use OPENAI_API_BASE for custom endpoint
15
  ZAI_API_KEY = os.environ.get("ZAI_API_KEY", "")
16
- ZAI_BASE_URL = "https://api.z.ai/api/anthropic"
17
-
18
- # For OpenAI-compatible APIs, set these env vars
19
- os.environ["OPENAI_API_KEY"] = ZAI_API_KEY
20
- os.environ["OPENAI_API_BASE"] = ZAI_BASE_URL
21
-
22
  data_designer = None
23
 
24
 
@@ -26,14 +18,27 @@ data_designer = None
26
  async def lifespan(app: FastAPI):
27
  global data_designer
28
  from data_designer.interface import DataDesigner
29
- data_designer = DataDesigner(artifact_path=tempfile.gettempdir())
 
 
 
 
 
 
 
 
 
 
 
 
 
30
  yield
31
 
32
 
33
  app = FastAPI(
34
  title="NeMo DataDesigner API",
35
- description="Synthetic data generation with NVIDIA NeMo DataDesigner and z.ai",
36
- version="1.2.0",
37
  lifespan=lifespan
38
  )
39
 
@@ -48,7 +53,7 @@ app.add_middleware(
48
 
49
  def build_config(request):
50
  import data_designer.config as dd
51
- from data_designer.config.models import ModelConfig, ChatCompletionInferenceParams, ModelProvider
52
 
53
  config_builder = dd.DataDesignerConfigBuilder()
54
  model_id = request.model.value
@@ -74,11 +79,11 @@ def build_config(request):
74
  )
75
  )
76
 
77
- # Use openai provider - LiteLLM will use OPENAI_API_BASE env var
78
  model_config = ModelConfig(
79
  alias="zai-model",
80
  model=f"openai/{model_id}",
81
- provider="openai",
82
  inference_parameters=ChatCompletionInferenceParams(
83
  temperature=request.temperature,
84
  max_tokens=request.max_tokens,
 
2
  from contextlib import asynccontextmanager
3
  from fastapi import FastAPI
4
  from fastapi.middleware.cors import CORSMiddleware
 
5
  import tempfile
6
 
7
  from models import (
 
10
  HealthResponse, ZaiModel
11
  )
12
 
 
13
  ZAI_API_KEY = os.environ.get("ZAI_API_KEY", "")
 
 
 
 
 
 
14
  data_designer = None
15
 
16
 
 
18
  async def lifespan(app: FastAPI):
19
  global data_designer
20
  from data_designer.interface import DataDesigner
21
+ from data_designer.config.models import ModelProvider
22
+
23
+ # Create custom z.ai provider
24
+ zai_provider = ModelProvider(
25
+ name="zai",
26
+ endpoint="https://api.z.ai/api/anthropic",
27
+ provider_type="openai",
28
+ api_key="ZAI_API_KEY",
29
+ )
30
+
31
+ data_designer = DataDesigner(
32
+ artifact_path=tempfile.gettempdir(),
33
+ model_providers=[zai_provider]
34
+ )
35
  yield
36
 
37
 
38
  app = FastAPI(
39
  title="NeMo DataDesigner API",
40
+ description="Synthetic data generation with z.ai",
41
+ version="2.0.0",
42
  lifespan=lifespan
43
  )
44
 
 
53
 
54
  def build_config(request):
55
  import data_designer.config as dd
56
+ from data_designer.config.models import ModelConfig, ChatCompletionInferenceParams
57
 
58
  config_builder = dd.DataDesignerConfigBuilder()
59
  model_id = request.model.value
 
79
  )
80
  )
81
 
82
+ # Use zai provider with openai format
83
  model_config = ModelConfig(
84
  alias="zai-model",
85
  model=f"openai/{model_id}",
86
+ provider="zai",
87
  inference_parameters=ChatCompletionInferenceParams(
88
  temperature=request.temperature,
89
  max_tokens=request.max_tokens,