"""
SimpleTool vLLM Server - Multi-Head Parallel Decoding for Real-Time Function Calling
Supports both v1 and v2 prompt formats. HTML clients need zero changes.
"""

import json
import time
import os
from typing import List, Dict, Any, Optional
from contextlib import asynccontextmanager

from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
import uvicorn

from vllm import LLM, SamplingParams

# --- Configuration ---
MODEL_PATH = "./models/RT-Qwen3-4B-AWQ-v2"
MODEL_VERSION = "v2"  # "v1" or "v2"; selects the prompt templates below
SERVER_HOST = "0.0.0.0"
SERVER_PORT = 8899
MAX_HISTORY = 6  # most recent history entries included in the prompt

os.environ.setdefault("CUDA_VISIBLE_DEVICES", "0")

# One opening tag per decoding head; each head's generation ends at the
# matching closing tag, <|null|>, or <|im_end|>.
HEAD_TAGS = ["<content>", "<function>", "<arg1>", "<arg2>", "<arg3>", "<arg4>", "<arg5>", "<arg6>"]
STOP_TOKENS = ["<|null|>", "</content>", "</function>", "</arg1>", "</arg2>", "</arg3>", "</arg4>", "</arg5>", "</arg6>", "<|im_end|>"]
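
# Illustration (values are made up, not actual model output): each head is a
# separate continuation of the same shared prefix, e.g. prefix + "<function>"
# may complete to "set_timer</function>" while prefix + "<arg1>" completes to
# "300</arg1>". All heads are submitted to vLLM as one batch, so with prefix
# caching enabled the shared prefix KV cache can be reused across heads.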

# --- v1 prompt format ---
V1_SYSTEM_TEMPLATE = """<|im_start|>system
You are a multi-head parallel function calling model.
## Output Heads

**Head 0 - <content>**: Natural language response
- Format: <content>response text</content>

**Head 1 - <function>**: Function names to call
- Format: <function>name</function>

**Head 2-7 - <arg1>-<arg6>**: Function arguments by position
- Format: <argN>value</argN>
- If Unnecessary: <argN><|null|></argN>

## Available Tools:

{tools_json}
<|im_end|>
"""

V1_USER_TEMPLATE = "<|im_start|>user\nenvironment: {env}\nhistory: [{hist}]\n\n{query}<|im_end|>\n<|im_start|>assistant\n"

# --- v2 prompt format ---
V2_SYSTEM_TEMPLATE = """<|im_start|>system
{system_prompt}

## Available Tools:

{tools_json}
<|im_end|>
"""

V2_USER_TEMPLATE = "<|im_start|>user\nhistory: [{hist}]\n\n{query}<|im_end|>\n<|im_start|>assistant\n"

V2_DEFAULT_SYSTEM = "You are a real-time function calling assistant. Convert user commands into function calls using the available tools."
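
# For reference, an assembled v2 prompt looks like this (the tool and history
# values are illustrative):
#
#   <|im_start|>system
#   You are a real-time function calling assistant. Convert user commands
#   into function calls using the available tools.
#
#   ## Available Tools:
#
#   {"type": "function", "function": {"name": "set_timer", ...}}
#   <|im_end|>
#   <|im_start|>user
#   history: [set_timer(seconds=60)]
#
#   set a five minute timer<|im_end|>
#   <|im_start|>assistant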


class Message(BaseModel):
    role: str
    content: str


class FCRequest(BaseModel):
    messages: List[Message]
    tools: List[Dict[str, Any]]

    environment: Optional[List[str]] = None  # v1: own template slot; v2: prepended to the query
    history: Optional[List[str]] = None      # previous calls, truncated to MAX_HISTORY

    system: Optional[str] = None             # v2 only: overrides V2_DEFAULT_SYSTEM

    max_tokens: int = 32
    temperature: float = 0.0
    include_content_head: bool = False       # also decode the natural-language <content> head
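
# An illustrative request body (the tool definition and values are examples,
# not fixed API data):
#
#   {
#     "messages": [{"role": "user", "content": "set a five minute timer"}],
#     "tools": [{"type": "function",
#                "function": {"name": "set_timer",
#                             "parameters": {"properties": {"seconds": {"type": "integer"}}}}}],
#     "history": ["set_timer(seconds=60)"]
#   }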


class FCResponse(BaseModel):
    success: bool
    function: Optional[str] = None
    args: Dict[str, Any] = {}
    heads: Dict[str, str] = {}  # raw text decoded by each head
    content: Optional[str] = None
    latency_ms: float = 0
    error: Optional[str] = None
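
# An illustrative successful response for the request above (values, including
# the latency, are examples):
#
#   {"success": true, "function": "set_timer", "args": {"seconds": 300},
#    "heads": {"function": "set_timer", "arg1": "300"}, "latency_ms": 42.0}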


class SimpleToolEngine:
    """Wraps a vLLM model and decodes all heads for a request in one batch."""

    def __init__(self, model_path: str, version: str = "v2"):
        self.model_path = model_path
        self.version = version
        self.llm: Optional[LLM] = None
        self.sampling_params = None

    def initialize(self):
        print(f"[SimpleTool] Loading model ({self.version}): {self.model_path}")
        self.llm = LLM(
            model=self.model_path,
            trust_remote_code=True,
            enable_prefix_caching=True,  # heads share one prefix; compute it once
            tensor_parallel_size=1,
            gpu_memory_utilization=0.8,
            max_model_len=4096,
            dtype="auto",
        )
        # Defaults used for warmup; call() builds per-request parameters.
        self.sampling_params = SamplingParams(
            temperature=0.0,
            max_tokens=32,
            stop=STOP_TOKENS,
            include_stop_str_in_output=True,
        )
        print(f"[SimpleTool] Model loaded! (version={self.version})")
        self._warmup()

    def _warmup(self):
        print("[SimpleTool] Warming up...")
        dummy_tools = '{"type":"function","function":{"name":"test","parameters":{}}}'
        if self.version == "v1":
            prefix = V1_SYSTEM_TEMPLATE.format(tools_json=dummy_tools)
            prefix += V1_USER_TEMPLATE.format(env="[]", hist="", query="test")
        else:
            prefix = V2_SYSTEM_TEMPLATE.format(system_prompt=V2_DEFAULT_SYSTEM, tools_json=dummy_tools)
            prefix += V2_USER_TEMPLATE.format(hist="", query="test")
        prompts = [prefix + tag for tag in HEAD_TAGS[:2]]  # two heads suffice to prime the caches
        self.llm.generate(prompts, self.sampling_params)
        print("[SimpleTool] Warmup complete!")

    def _build_tools_json(self, tools: List[Dict]) -> str:
        return "\n".join(json.dumps(t, ensure_ascii=False) for t in tools)

    def _extract_param_info(self, tools: List[Dict]) -> List[str]:
        """Collect unique parameter names across all tools, in declaration order (max 6)."""
        names = []
        for tool in tools:
            func = tool.get("function", {})
            params = func.get("parameters", {}).get("properties", {})
            for name in params.keys():
                if name not in names:
                    names.append(name)
        return names[:6]

    def _get_max_args(self, tools: List[Dict]) -> int:
        """Largest parameter count across tools, capped at the six arg heads."""
        max_args = 0
        for tool in tools:
            func = tool.get("function", {})
            params = func.get("parameters", {}).get("properties", {})
            max_args = max(max_args, len(params))
        return min(max_args, 6)

    def _build_prompt(self, request: FCRequest) -> str:
        """Build the shared prompt prefix according to the prompt version."""
        tools_json = self._build_tools_json(request.tools)

        # The last user message is the query.
        query = ""
        for msg in request.messages:
            if msg.role == "user":
                query = msg.content

        hist_list = (request.history or [])[-MAX_HISTORY:]
        hist_str = ", ".join(hist_list) if hist_list else ""

        if self.version == "v1":
            # v1 has a dedicated environment slot in the user template.
            env_str = json.dumps(request.environment or [], ensure_ascii=False)
            system_part = V1_SYSTEM_TEMPLATE.format(tools_json=tools_json)
            user_part = V1_USER_TEMPLATE.format(env=env_str, hist=hist_str, query=query)
        else:
            # v2 takes a custom system prompt and, if an environment is given,
            # prepends it to the user query instead.
            system_prompt = request.system or V2_DEFAULT_SYSTEM
            system_part = V2_SYSTEM_TEMPLATE.format(
                system_prompt=system_prompt,
                tools_json=tools_json,
            )
            env_prefix = ""
            if request.environment:
                env_prefix = "environment: " + json.dumps(request.environment, ensure_ascii=False) + "\n"
            user_part = V2_USER_TEMPLATE.format(
                hist=hist_str,
                query=env_prefix + query,
            )

        return system_part + user_part

    def call(self, request: FCRequest) -> FCResponse:
        start = time.perf_counter()

        full_prefix = self._build_prompt(request)

        # One prompt per active head: the function head, one head per possible
        # argument, and optionally the <content> head.
        max_args = self._get_max_args(request.tools)
        active_tags = ["<function>"] + [f"<arg{i}>" for i in range(1, max_args + 1)]
        if request.include_content_head:
            active_tags = ["<content>"] + active_tags

        prompts = [full_prefix + tag for tag in active_tags]
        # Honor the per-request sampling settings (the init-time defaults are
        # only used for warmup).
        sampling_params = SamplingParams(
            temperature=request.temperature,
            max_tokens=request.max_tokens,
            stop=STOP_TOKENS,
            include_stop_str_in_output=True,
        )
        outputs = self.llm.generate(prompts, sampling_params)

        latency_ms = (time.perf_counter() - start) * 1000

        # Collect each head's text, stripping the stop token that ended it;
        # outputs preserve prompt order.
        heads = {}
        head_names = []
        if request.include_content_head:
            head_names.append("content")
        head_names.append("function")
        head_names.extend([f"arg{i}" for i in range(1, max_args + 1)])

        for i, output in enumerate(outputs):
            text = output.outputs[0].text.strip()
            for stop in STOP_TOKENS:
                if text.endswith(stop):
                    text = text[:-len(stop)].strip()
                    break
            heads[head_names[i]] = text

        func_name = heads.get("function", "").strip()
        if not func_name or func_name == "<|null|>":
            return FCResponse(
                success=False,
                heads=heads,
                content=heads.get("content"),
                latency_ms=latency_ms,
                error="No function called",
            )

        # Map positional arg heads onto parameter names; coerce numeric values
        # (a single leading minus sign is allowed).
        param_names = self._extract_param_info(request.tools)
        args = {}
        for i, name in enumerate(param_names):
            val = heads.get(f"arg{i+1}", "").strip()
            if val and val != "<|null|>":
                unsigned = val[1:] if val.startswith("-") else val
                if unsigned.isdigit():
                    args[name] = int(val)
                elif unsigned.replace(".", "", 1).isdigit():
                    args[name] = float(val)
                else:
                    args[name] = val.lower().strip()

        return FCResponse(
            success=True,
            function=func_name,
            args=args,
            heads=heads,
            content=heads.get("content"),
            latency_ms=latency_ms,
        )


engine: Optional[SimpleToolEngine] = None


@asynccontextmanager
async def lifespan(app: FastAPI):
    global engine
    engine = SimpleToolEngine(MODEL_PATH, version=MODEL_VERSION)
    engine.initialize()
    yield
    print("[Server] Shutdown")


app = FastAPI(title="SimpleTool Server", version="2.0.0", lifespan=lifespan)

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


@app.get("/health")
async def health():
    return {
        "status": "ok",
        "loaded": engine is not None and engine.llm is not None,
        "model": MODEL_PATH,
        "version": MODEL_VERSION,
    }
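
# Quick check once the server is up (port per SERVER_PORT above):
#
#   curl -s http://localhost:8899/health
#   -> {"status": "ok", "loaded": true, "model": "./models/RT-Qwen3-4B-AWQ-v2", "version": "v2"}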


@app.post("/v1/function_call", response_model=FCResponse)
async def function_call(request: FCRequest):
    if engine is None or engine.llm is None:
        raise HTTPException(503, "Model not loaded")
    try:
        # Note: llm.generate() is a blocking call, so requests are served one
        # at a time; that is fine for a single local client.
        return engine.call(request)
    except Exception as e:
        import traceback
        traceback.print_exc()
        return FCResponse(success=False, error=str(e), latency_ms=0)
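
# Example call (the tool definition and values are illustrative):
#
#   curl -s http://localhost:8899/v1/function_call \
#     -H "Content-Type: application/json" \
#     -d '{"messages": [{"role": "user", "content": "set a five minute timer"}],
#          "tools": [{"type": "function",
#                     "function": {"name": "set_timer",
#                                  "parameters": {"properties": {"seconds": {"type": "integer"}}}}}]}'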


if __name__ == "__main__":
    print(r"""
======================================================================

  SimpleTool vLLM Server v2.0
  Multi-Head Parallel Decoding | v1/v2 Compatible

  Run Demos:  open demos/*.html in a browser
  Build New:  send simpletool-game-guide.md to an AI (Claude, Gemini, ...)
              to easily build your own HTML games
  Endpoints:
    GET  /health            - Health check (+ version info)
    POST /v1/function_call  - Function call API (v1 & v2)

======================================================================
""")
    uvicorn.run(app, host=SERVER_HOST, port=SERVER_PORT)