mindchain committed on
Commit
2778014
·
verified ·
1 Parent(s): ef0bf4f

v3.0: Use LiteLLM anthropic/ prefix with ANTHROPIC_BASE_URL

Browse files
Files changed (1) hide show
  1. app.py +87 -98
app.py CHANGED
@@ -1,10 +1,8 @@
1
  import os
2
- import httpx
3
  from contextlib import asynccontextmanager
4
  from fastapi import FastAPI
5
  from fastapi.middleware.cors import CORSMiddleware
6
  import tempfile
7
- import random
8
 
9
  from models import (
10
  GenerateRequest, GenerateResponse,
@@ -12,13 +10,30 @@ from models import (
12
  HealthResponse, ZaiModel
13
  )
14
 
 
15
  ZAI_API_KEY = os.environ.get("ZAI_API_KEY", "")
16
  ZAI_BASE_URL = "https://api.z.ai/api/anthropic"
17
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
  app = FastAPI(
19
  title="NeMo DataDesigner API",
20
- description="Synthetic data generation with z.ai",
21
- version="2.0.0"
 
22
  )
23
 
24
  app.add_middleware(
@@ -30,122 +45,96 @@ app.add_middleware(
30
  )
31
 
32
 
33
async def call_zai(prompt: str, model: str, temperature: float, max_tokens: int) -> str:
    """Call the z.ai Messages API (Anthropic wire format) and return the reply text.

    Args:
        prompt: User prompt, sent as a single user-role message.
        model: Model identifier forwarded verbatim to the API.
        temperature: Sampling temperature for generation.
        max_tokens: Maximum number of tokens the model may generate.

    Returns:
        The text of the first content block of the response.

    Raises:
        Exception: If the API responds with a non-200 status code.
    """
    headers = {
        "x-api-key": ZAI_API_KEY,
        "anthropic-version": "2023-06-01",
        "content-type": "application/json",
    }
    payload = {
        "model": model,
        "max_tokens": max_tokens,
        # BUG FIX: `temperature` was accepted by this function but silently
        # dropped from the request body; callers' temperature had no effect.
        "temperature": temperature,
        "messages": [{"role": "user", "content": prompt}],
    }
    async with httpx.AsyncClient(timeout=60.0) as client:
        response = await client.post(
            f"{ZAI_BASE_URL}/v1/messages",
            headers=headers,
            json=payload,
        )
    if response.status_code != 200:
        raise Exception(f"z.ai API error: {response.status_code} - {response.text}")
    # Anthropic format: content is a list of blocks; first block carries the text.
    return response.json()["content"][0]["text"]
53
-
54
-
55
def sample_value(sampler_type: str, params: dict) -> str:
    """Draw one random value for the given sampler type, returned as a string.

    Supported types: "CATEGORY" (uniform choice from ``params['values']``),
    "UNIFORM" (integer in [low, high]), "GAUSSIAN" (normal draw rounded to
    2 decimals). Any other type yields the literal string "default".
    """
    if sampler_type == "CATEGORY":
        return random.choice(params.get("values", ["A", "B", "C"]))
    if sampler_type == "UNIFORM":
        # NOTE(review): randint expects ints — assumes low/high are integers.
        return str(random.randint(params.get("low", 0), params.get("high", 100)))
    if sampler_type == "GAUSSIAN":
        drawn = random.gauss(params.get("mean", 0), params.get("std", 1))
        return str(round(drawn, 2))
    return "default"
70
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
71
 
72
def render_prompt(template: str, context: dict) -> str:
    """Render a prompt template by substituting ``{{ key }}`` placeholders.

    Replaces every ``{{ key }}`` placeholder — with any amount of internal
    whitespace, including none — by ``str(value)`` for each key/value pair
    in *context*. This generalizes the previous implementation, which only
    matched the exact forms ``{{ key }}`` and ``{{key}}``, while remaining
    backward compatible for both of those forms.

    Args:
        template: Prompt text containing ``{{ key }}`` placeholders.
        context: Mapping of placeholder names to replacement values.

    Returns:
        The template with all known placeholders substituted.
    """
    import re  # local import: keeps the module's import block untouched

    result = template
    for key, value in context.items():
        pattern = r"\{\{\s*" + re.escape(str(key)) + r"\s*\}\}"
        # Callable replacement so backslashes in the value stay literal.
        result = re.sub(pattern, lambda _m, _v=str(value): _v, result)
    return result
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
79
 
80
 
81
@app.get("/", response_model=HealthResponse)
async def root():
    """Liveness endpoint; reports whether the z.ai API key is configured."""
    configured = bool(ZAI_API_KEY)
    return HealthResponse(status="healthy", model="z.ai", api_configured=configured)
84
 
85
 
86
@app.get("/health", response_model=HealthResponse)
async def health():
    """Health-check endpoint; mirrors the root liveness response."""
    configured = bool(ZAI_API_KEY)
    return HealthResponse(status="healthy", model="z.ai", api_configured=configured)
89
 
90
 
91
@app.post("/generate", response_model=GenerateResponse)
async def generate(request: GenerateRequest):
    """Generate synthetic data using z.ai API."""
    try:
        model = request.model.value
        # Split columns once: samplers run locally, llm_text columns call z.ai.
        sampler_cols = {c.name: c for c in request.columns if c.type == "sampler"}
        llm_cols = [c for c in request.columns if c.type == "llm_text"]

        records = []
        for _ in range(request.num_records):
            # Sampler values come first so they can feed the LLM prompts.
            record = {
                name: sample_value(col.params.get("sampler_type", "CATEGORY"), col.params)
                for name, col in sampler_cols.items()
            }
            for col in llm_cols:
                rendered = render_prompt(col.params.get("prompt", "Generate text"), record)
                record[col.name] = await call_zai(
                    rendered, model, request.temperature, request.max_tokens
                )
            records.append(record)

        return GenerateResponse(success=True, data=records, record_count=len(records))
    except Exception as e:
        # Surface any failure as a structured error response instead of a 500.
        return GenerateResponse(success=False, error=str(e))
123
 
124
 
125
@app.post("/preview", response_model=PreviewResponse)
async def preview(request: PreviewRequest):
    """Preview a single record."""
    try:
        model = request.model.value
        sampler_cols = {c.name: c for c in request.columns if c.type == "sampler"}
        llm_cols = [c for c in request.columns if c.type == "llm_text"]

        # Sampler values first, so LLM prompts can reference them.
        record = {
            name: sample_value(col.params.get("sampler_type", "CATEGORY"), col.params)
            for name, col in sampler_cols.items()
        }
        for col in llm_cols:
            rendered = render_prompt(col.params.get("prompt", "Generate text"), record)
            record[col.name] = await call_zai(
                rendered, model, request.temperature, request.max_tokens
            )

        return PreviewResponse(success=True, sample=record)
    except Exception as e:
        return PreviewResponse(success=False, error=str(e))
151
 
 
1
  import os
 
2
  from contextlib import asynccontextmanager
3
  from fastapi import FastAPI
4
  from fastapi.middleware.cors import CORSMiddleware
5
  import tempfile
 
6
 
7
  from models import (
8
  GenerateRequest, GenerateResponse,
 
10
  HealthResponse, ZaiModel
11
  )
12
 
13
# Configure z.ai as an Anthropic-compatible endpoint for LiteLLM.
ZAI_API_KEY = os.environ.get("ZAI_API_KEY", "")
ZAI_BASE_URL = "https://api.z.ai/api/anthropic"

# LiteLLM routes "anthropic/" models through these environment variables.
# NOTE(review): if ZAI_API_KEY is unset, an empty key is exported — confirm intended.
os.environ["ANTHROPIC_API_KEY"] = ZAI_API_KEY
os.environ["ANTHROPIC_BASE_URL"] = ZAI_BASE_URL

# Shared DataDesigner instance; assigned once at application startup.
data_designer = None


@asynccontextmanager
async def lifespan(app: FastAPI):
    """FastAPI lifespan hook: build the shared DataDesigner on startup."""
    global data_designer
    # Deferred import: resolved only when the app actually starts.
    from data_designer.interface import DataDesigner

    data_designer = DataDesigner(artifact_path=tempfile.gettempdir())
    yield
30
+
31
+
32
# Application object; `lifespan` wires up the shared DataDesigner instance.
app = FastAPI(
    title="NeMo DataDesigner API",
    description="Synthetic data generation with DataDesigner + z.ai",
    version="3.0.0",
    lifespan=lifespan,
)
38
 
39
  app.add_middleware(
 
45
  )
46
 
47
 
48
def build_config(request):
    """Translate an API request into a DataDesigner config builder.

    Sampler columns become SamplerColumnConfig entries; llm_text columns
    become LLMTextColumnConfig entries bound to the "zai-model" alias.
    A single ModelConfig routes that alias to z.ai via the LiteLLM
    "anthropic/" prefix (which picks up ANTHROPIC_BASE_URL).
    """
    import data_designer.config as dd
    from data_designer.config.models import ModelConfig, ChatCompletionInferenceParams

    builder = dd.DataDesignerConfigBuilder()
    model_id = request.model.value

    for col in request.columns:
        if col.type == "sampler":
            # Fall back to CATEGORY for unknown sampler-type strings.
            type_str = col.params.get("sampler_type", "CATEGORY")
            resolved_type = getattr(dd.SamplerType, type_str, dd.SamplerType.CATEGORY)
            builder.add_column(
                dd.SamplerColumnConfig(
                    name=col.name,
                    sampler_type=resolved_type,
                    params=get_sampler_params(resolved_type, col.params),
                )
            )
        elif col.type == "llm_text":
            builder.add_column(
                dd.LLMTextColumnConfig(
                    name=col.name,
                    model_alias="zai-model",
                    prompt=col.params.get("prompt", "Generate text"),
                )
            )

    # LiteLLM recognizes the "anthropic/" prefix and uses ANTHROPIC_BASE_URL.
    inference = ChatCompletionInferenceParams(
        temperature=request.temperature,
        max_tokens=request.max_tokens,
    )
    builder.add_model_config(
        ModelConfig(
            alias="zai-model",
            model=f"anthropic/{model_id}",
            provider="anthropic",
            inference_parameters=inference,
        )
    )

    return builder
89
+
90
+
91
def get_sampler_params(sampler_type, params):
    """Map raw request params onto the matching DataDesigner sampler params.

    Unknown sampler types fall back to a single-value CATEGORY sampler.
    """
    import data_designer.config as dd

    # Enum members expose .name; anything else is stringified for comparison.
    type_name = sampler_type.name if hasattr(sampler_type, "name") else str(sampler_type)

    if type_name == "UNIFORM":
        return dd.UniformSamplerParams(low=params.get("low", 0), high=params.get("high", 100))
    if type_name == "GAUSSIAN":
        return dd.GaussianSamplerParams(mean=params.get("mean", 0), std=params.get("std", 1))
    if type_name == "CATEGORY":
        return dd.CategorySamplerParams(values=params.get("values", ["A", "B", "C"]))
    return dd.CategorySamplerParams(values=["default"])
103
 
104
 
105
@app.get("/", response_model=HealthResponse)
async def root():
    """Liveness endpoint; reports whether the z.ai API key is configured."""
    configured = bool(ZAI_API_KEY)
    return HealthResponse(status="healthy", model="data-designer", api_configured=configured)
108
 
109
 
110
@app.get("/health", response_model=HealthResponse)
async def health():
    """Health-check endpoint; mirrors the root liveness response."""
    configured = bool(ZAI_API_KEY)
    return HealthResponse(status="healthy", model="data-designer", api_configured=configured)
113
 
114
 
115
@app.post("/generate", response_model=GenerateResponse)
async def generate(request: GenerateRequest):
    """Generate a synthetic dataset through the shared DataDesigner instance."""
    try:
        builder = build_config(request)
        result = data_designer.create(
            config_builder=builder,
            num_records=request.num_records,
            dataset_name="api-dataset",
        )
        # Materialize the result as a list of plain row dicts.
        rows = result.load_dataset().to_dict(orient="records")
        return GenerateResponse(success=True, data=rows, record_count=len(rows))
    except Exception as e:
        # Surface any failure as a structured error response instead of a 500.
        return GenerateResponse(success=False, error=str(e))
129
 
130
 
131
@app.post("/preview", response_model=PreviewResponse)
async def preview(request: PreviewRequest):
    """Build the config and return a single preview record."""
    try:
        builder = build_config(request)
        result = data_designer.preview(config_builder=builder, num_records=1)
        frame = result.dataset
        # Empty preview datasets yield an empty sample rather than an error.
        sample = frame.to_dict(orient="records")[0] if len(frame) > 0 else {}
        return PreviewResponse(success=True, sample=sample)
    except Exception as e:
        return PreviewResponse(success=False, error=str(e))
140