Cialtion committed on
Commit
7fdb7d3
Β·
verified Β·
1 Parent(s): 420ec60

Update 02_server.py

Browse files
Files changed (1) hide show
  1. 02_server.py +336 -0
02_server.py ADDED
@@ -0,0 +1,336 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ SimpleTool vLLM Server - Multi-Head Parallel Decoding for Real-Time Function Calling
4
+ Supports both v1 and v2 prompt formats. HTML clients need zero changes.
5
+ """
6
+
7
import json
import math
import os
import time
from contextlib import asynccontextmanager
from typing import Any, Dict, List, Optional

import uvicorn
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from vllm import LLM, SamplingParams
19
+
20
# ==================== Config ====================
MODEL_PATH = "./models/RT-Qwen3-4B-AWQ-v2" # v2 model path
MODEL_VERSION = "v2" # "v1" or "v2"
# NOTE(review): 0.0.0.0 binds all interfaces — fine for demos, tighten for prod.
SERVER_HOST = "0.0.0.0"
SERVER_PORT = 8899
# Only the last MAX_HISTORY history strings are kept when building the prompt.
MAX_HISTORY = 6

# Pin to one GPU unless the environment already chose a device set.
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "0")

# ==================== Multi-Head Tags ====================
# Opening tag appended to the shared prefix for each head; the model then
# generates that head's value as a continuation inside the tag.
HEAD_TAGS = ["<content>", "<function>", "<arg1>", "<arg2>", "<arg3>", "<arg4>", "<arg5>", "<arg6>"]
# Any closing tag, <|null|>, or <|im_end|> terminates a head's decode.
STOP_TOKENS = ["<|null|>", "</content>", "</function>", "</arg1>", "</arg2>", "</arg3>", "</arg4>", "</arg5>", "</arg6>", "<|im_end|>"]
32
+
33
# ── v1: generic head-format instructions in system, domain context in user ──
# NOTE: these template literals are part of the trained prompt format —
# do not alter their text or whitespace.
V1_SYSTEM_TEMPLATE = """<|im_start|>system
You are a multi-head parallel function calling model.
## Output Heads

**Head 0 - <content>**: Natural language response
- Format: <content>response text</content>

**Head 1 - <function>**: Function names to call
- Format: <function>name</function>

**Head 2-7 - <arg1>-<arg6>**: Function arguments by position
- Format: <argN>value</argN>
- If Unnecessary: <argN><|null|></argN>

## Available Tools:

{tools_json}
<|im_end|>
"""

# v1 user turn carries environment + history + query; the trailing
# "<|im_start|>assistant\n" leaves the model ready to emit a head tag.
V1_USER_TEMPLATE = "<|im_start|>user\nenvironment: {env}\nhistory: [{hist}]\n\n{query}<|im_end|>\n<|im_start|>assistant\n"

# ── v2: domain system prompt + tools in system, leaner user turn ──
V2_SYSTEM_TEMPLATE = """<|im_start|>system
{system_prompt}

## Available Tools:

{tools_json}
<|im_end|>
"""

V2_USER_TEMPLATE = "<|im_start|>user\nhistory: [{hist}]\n\n{query}<|im_end|>\n<|im_start|>assistant\n"

# Default system prompt when HTML client doesn't send one (backward compat)
V2_DEFAULT_SYSTEM = "You are a real-time function calling assistant. Convert user commands into function calls using the available tools."
70
+
71
+
72
# ==================== Data Models ====================
class Message(BaseModel):
    """One chat message in the request."""

    role: str     # e.g. "user" — the engine's prompt builder only reads user-role turns
    content: str  # raw message text
76
+
77
+
78
class FCRequest(BaseModel):
    """Request payload for POST /v1/function_call (accepts v1 and v2 clients)."""

    messages: List[Message]      # chat turns; the last "user" message becomes the query
    tools: List[Dict[str, Any]]  # OpenAI-style tool specs ({"function": {"name", "parameters"}})
    # ── v1 fields (still accepted, used when version=v1) ──
    environment: Optional[List[str]] = None  # scene/state strings; folded into the query under v2
    history: Optional[List[str]] = None      # prior-turn strings; only the last MAX_HISTORY are used
    # ── v2 optional: domain system prompt ──
    system: Optional[str] = None             # overrides V2_DEFAULT_SYSTEM when provided
    # ── shared ──
    max_tokens: int = 32                # requested decode budget per head
    temperature: float = 0.0            # requested sampling temperature
    include_content_head: bool = False  # also decode the <content> head when True
90
+
91
+
92
class FCResponse(BaseModel):
    """Structured result of one multi-head decode."""

    success: bool                   # True iff a function name was decoded
    function: Optional[str] = None  # decoded function name (None on failure)
    args: Dict[str, Any] = {}       # parsed arguments, typed best-effort
    heads: Dict[str, str] = {}      # raw decoded text per head (stop token stripped)
    content: Optional[str] = None   # <content> head text, if that head was decoded
    latency_ms: float = 0           # wall-clock generate latency in milliseconds
    error: Optional[str] = None     # human-readable failure reason
    # NOTE: the mutable {} defaults are safe — pydantic copies field defaults
    # per instance rather than sharing them.
100
+
101
+
102
# ==================== SimpleTool Engine ====================
class SimpleToolEngine:
    """Multi-head parallel decoding engine.

    One shared prompt prefix is built per request; each output head
    (<content>, <function>, <arg1>..<arg6>) is decoded as a separate
    continuation of that prefix in a single batched ``generate`` call.
    ``enable_prefix_caching`` lets vLLM reuse the shared prefix across
    the heads.
    """

    def __init__(self, model_path: str, version: str = "v2"):
        self.model_path = model_path
        self.version = version          # "v1" or "v2" prompt format
        self.llm: Optional[LLM] = None  # set by initialize()
        self.sampling_params = None     # default params, set by initialize()

    def initialize(self):
        """Load the model, build default sampling params, and warm up."""
        print(f"[SimpleTool] Loading model ({self.version}): {self.model_path}")
        self.llm = LLM(
            model=self.model_path,
            trust_remote_code=True,
            enable_prefix_caching=True,  # heads share one prefix — cache it
            tensor_parallel_size=1,
            gpu_memory_utilization=0.8,
            max_model_len=4096,
            dtype="auto",
        )
        self.sampling_params = SamplingParams(
            temperature=0.0,
            max_tokens=32,
            stop=STOP_TOKENS,
            include_stop_str_in_output=True
        )
        print(f"[SimpleTool] Model loaded! (version={self.version})")
        self._warmup()

    def _warmup(self):
        """Prime kernels and the prefix cache with a dummy request."""
        print("[SimpleTool] Warming up...")
        dummy_tools = '{"type":"function","function":{"name":"test","parameters":{}}}'
        if self.version == "v1":
            prefix = V1_SYSTEM_TEMPLATE.format(tools_json=dummy_tools)
            prefix += V1_USER_TEMPLATE.format(env="[]", hist="", query="test")
        else:
            prefix = V2_SYSTEM_TEMPLATE.format(system_prompt=V2_DEFAULT_SYSTEM, tools_json=dummy_tools)
            prefix += V2_USER_TEMPLATE.format(hist="", query="test")
        prompts = [prefix + tag for tag in HEAD_TAGS[:2]]  # function + arg1 enough
        self.llm.generate(prompts, self.sampling_params)
        print("[SimpleTool] Warmup complete!")

    def _build_tools_json(self, tools: List[Dict]) -> str:
        """Serialize tools one JSON object per line (the prompt format)."""
        return "\n".join(json.dumps(t, ensure_ascii=False) for t in tools)

    def _extract_param_info(self, tools: List[Dict]) -> List[str]:
        """Ordered union of parameter names across all tools, capped at 6.

        argN heads are positional, so this ordering defines which head
        maps to which parameter name.
        """
        names = []
        for tool in tools:
            func = tool.get("function", {})
            params = func.get("parameters", {}).get("properties", {})
            for name in params.keys():
                if name not in names:
                    names.append(name)
        return names[:6]

    def _get_max_args(self, tools: List[Dict]) -> int:
        """Largest per-tool parameter count, capped at the 6 arg heads."""
        max_args = 0
        for tool in tools:
            func = tool.get("function", {})
            params = func.get("parameters", {}).get("properties", {})
            max_args = max(max_args, len(params))
        return min(max_args, 6)

    @staticmethod
    def _coerce_value(val: str) -> Any:
        """Best-effort typing of a decoded head value.

        Tries int first (so "-5" stays an integer), then a finite float,
        else falls back to the lowercased string.

        Fixes two defects in the previous digit-check heuristic: inputs
        like "--3" passed the check but crashed float(), and negative
        integers were typed as floats.
        """
        try:
            return int(val)
        except ValueError:
            pass
        try:
            num = float(val)
        except ValueError:
            return val.lower()
        # Reject inf/nan so values like "inf" stay strings (matches the old
        # behavior and keeps the JSON response serializable).
        return num if math.isfinite(num) else val.lower()

    def _build_prompt(self, request: "FCRequest") -> str:
        """Build the shared prefix according to version."""
        tools_json = self._build_tools_json(request.tools)

        # Query = the last user-role message.
        query = ""
        for msg in request.messages:
            if msg.role == "user":
                query = msg.content

        hist_list = (request.history or [])[-MAX_HISTORY:]
        hist_str = ", ".join(hist_list) if hist_list else ""

        if self.version == "v1":
            # ── v1: head descriptions + tools in system, env+history+query in user ──
            env_str = json.dumps(request.environment or [], ensure_ascii=False)
            system_part = V1_SYSTEM_TEMPLATE.format(tools_json=tools_json)
            user_part = V1_USER_TEMPLATE.format(env=env_str, hist=hist_str, query=query)
        else:
            # ── v2: domain system + tools in system, history+query in user ──
            # If client sends a system prompt, use it; otherwise use default.
            system_prompt = request.system or V2_DEFAULT_SYSTEM
            system_part = V2_SYSTEM_TEMPLATE.format(
                system_prompt=system_prompt,
                tools_json=tools_json
            )
            # Backward compat: if environment is provided (old HTML clients),
            # prepend it to the query so the model still sees context.
            env_prefix = ""
            if request.environment:
                env_prefix = "environment: " + json.dumps(request.environment, ensure_ascii=False) + "\n"
            user_part = V2_USER_TEMPLATE.format(
                hist=hist_str,
                query=env_prefix + query
            )

        return system_part + user_part

    def call(self, request: "FCRequest") -> "FCResponse":
        """Run one multi-head decode and assemble the structured result."""
        start = time.perf_counter()

        full_prefix = self._build_prompt(request)

        # Decode only the heads this toolset can actually use.
        max_args = self._get_max_args(request.tools)
        active_tags = ["<function>"] + [f"<arg{i}>" for i in range(1, max_args + 1)]
        if request.include_content_head:
            active_tags = ["<content>"] + active_tags

        prompts = [full_prefix + tag for tag in active_tags]

        # Honor per-request decode settings (previously accepted by the API
        # but silently ignored); reuse the precomputed params for defaults.
        params = self.sampling_params
        if request.max_tokens != 32 or request.temperature != 0.0:
            params = SamplingParams(
                temperature=request.temperature,
                max_tokens=request.max_tokens,
                stop=STOP_TOKENS,
                include_stop_str_in_output=True
            )
        outputs = self.llm.generate(prompts, params)

        latency_ms = (time.perf_counter() - start) * 1000

        # Map each batched output back to its head name (same order as prompts).
        heads = {}
        head_names = []
        if request.include_content_head:
            head_names.append("content")
        head_names.append("function")
        head_names.extend([f"arg{i}" for i in range(1, max_args + 1)])

        for i, output in enumerate(outputs):
            text = output.outputs[0].text.strip()
            # Strip the stop string vLLM was told to include in the output.
            for stop in STOP_TOKENS:
                if text.endswith(stop):
                    text = text[:-len(stop)].strip()
                    break
            heads[head_names[i]] = text

        func_name = heads.get("function", "").strip()
        if not func_name or func_name == "<|null|>":
            return FCResponse(
                success=False,
                heads=heads,
                content=heads.get("content"),
                latency_ms=latency_ms,
                error="No function called"
            )

        # Positional argN heads map onto the ordered union of parameter names.
        param_names = self._extract_param_info(request.tools)
        args = {}
        for i, name in enumerate(param_names):
            val = heads.get(f"arg{i+1}", "").strip()
            if val and val != "<|null|>":
                args[name] = self._coerce_value(val)

        return FCResponse(
            success=True,
            function=func_name,
            args=args,
            heads=heads,
            content=heads.get("content"),
            latency_ms=latency_ms
        )
265
+
266
+
267
# ==================== FastAPI ====================
# Module-level singleton engine, created once at server startup.
engine: Optional[SimpleToolEngine] = None


@asynccontextmanager
async def lifespan(app: FastAPI):
    """FastAPI lifespan hook: load the model before serving requests."""
    global engine
    engine = SimpleToolEngine(MODEL_PATH, version=MODEL_VERSION)
    engine.initialize()
    yield  # server handles requests while suspended here
    print("[Server] Shutdown")
278
+
279
+
280
app = FastAPI(title="SimpleTool Server", version="2.0.0", lifespan=lifespan)

# Wide-open CORS so browser-based HTML demos can call the API from any origin.
# NOTE(review): browsers reject allow_origins=["*"] combined with
# allow_credentials=True for credentialed requests; fine for local demos,
# tighten for production.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
289
+
290
+
291
@app.get("/health")
async def health():
    """Liveness probe: reports model/engine readiness and configuration."""
    model_ready = engine is not None and engine.llm is not None
    return {
        "status": "ok",
        "loaded": model_ready,
        "model": MODEL_PATH,
        "version": MODEL_VERSION,
    }
299
+
300
+
301
@app.post("/v1/function_call", response_model=FCResponse)
async def function_call(request: FCRequest):
    """Decode one function call; engine errors come back as success=False."""
    ready = engine is not None and engine.llm is not None
    if not ready:
        raise HTTPException(503, "Model not loaded")
    try:
        result = engine.call(request)
    except Exception as exc:
        import traceback
        traceback.print_exc()
        return FCResponse(success=False, error=str(exc), latency_ms=0)
    return result
311
+
312
+
313
if __name__ == "__main__":
    # NOTE(review): the original banner literal was mojibake-garbled (UTF-8
    # box-drawing bytes decoded as Latin-1, e.g. "β•‘" for "║"); re-rendered
    # here with the intended characters. Cosmetic output only.
    print(r"""
╔════════════════════════════════════════════════════════════════════╗
║                                                                    ║
║   ███████╗██╗███╗   ███╗██████╗ ██╗     ███████╗                   ║
║   ██╔════╝██║████╗ ████║██╔══██╗██║     ██╔════╝                   ║
║   ███████╗██║██╔████╔██║██████╔╝██║     █████╗                     ║
║   ╚════██║██║██║╚██╔╝██║██╔═══╝ ██║     ██╔══╝                     ║
║   ███████║██║██║ ╚═╝ ██║██║     ███████╗███████╗                   ║
║   ╚══════╝╚═╝╚═╝     ╚═╝╚═╝     ╚══════╝╚══════╝                   ║
║                                                                    ║
║   SimpleTool vLLM-Server v2.0                                      ║
║   Multi-Head Parallel Decoding — v1/v2 Compatible                  ║
║                                                                    ║
║   Run Demos:  Open demos/*.html in browser                         ║
║   Build New:  Send simpletool-game-guide.md to an AI assistant     ║
║               (Claude, Gemini, ...) to build your own HTML games   ║
║   Endpoints:                                                       ║
║     GET  /health            - Health check (+ version info)        ║
║     POST /v1/function_call  - Function call API (v1 & v2)          ║
║                                                                    ║
╚════════════════════════════════════════════════════════════════════╝
""")
    # Blocking call: serves until interrupted.
    uvicorn.run(app, host=SERVER_HOST, port=SERVER_PORT)