| """ |
| FastAPI server example for text generation using SGLang Engine and demonstrating client usage. |
| |
| Starts the server, sends requests to it, and prints responses. |
| |
| Usage: |
| python fastapi_engine_inference.py --model-path Qwen/Qwen2.5-0.5B-Instruct --tp_size 1 --host 127.0.0.1 --port 8000 [--startup-timeout 60] |
| """ |
|
|
import os
import subprocess
import time
from contextlib import asynccontextmanager

import requests
from fastapi import FastAPI, HTTPException, Request

import sglang as sgl
from sglang.utils import terminate_process

# Populated by the lifespan handler once the engine has loaded.
engine = None


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Manages SGLang engine initialization and shutdown over the server's lifetime."""
    global engine

    print("Loading SGLang engine...")
    # MODEL_PATH and TP_SIZE are set by the __main__ block before it launches
    # the uvicorn subprocess; fall back to the script defaults if absent.
    engine = sgl.Engine(
        model_path=os.getenv("MODEL_PATH", "Qwen/Qwen2.5-0.5B-Instruct"),
        tp_size=int(os.getenv("TP_SIZE", "1")),
    )
    print("SGLang engine loaded.")
    yield
| print("Shutting down SGLang engine...") |
| |
| print("SGLang engine shutdown.") |
|
|
|
|
| app = FastAPI(lifespan=lifespan) |
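
# The app can also be served directly with uvicorn; the lifespan handler reads
# its settings from the environment, so export MODEL_PATH/TP_SIZE first, e.g.:
#   MODEL_PATH=Qwen/Qwen2.5-0.5B-Instruct TP_SIZE=1 \
#       uvicorn fastapi_engine_inference:app --host 127.0.0.1 --port 8000
# FastAPI also serves auto-generated API docs at /docs, which start_server()
# below uses as a readiness probe.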
|
|
|
|
| @app.post("/generate") |
| async def generate_text(request: Request): |
| """FastAPI endpoint to handle text generation requests.""" |
| global engine |
| if not engine: |
| return {"error": "Engine not initialized"}, 503 |
|
|
| try: |
| data = await request.json() |
| prompt = data.get("prompt") |
| max_new_tokens = data.get("max_new_tokens", 128) |
| temperature = data.get("temperature", 0.7) |
|
|
| if not prompt: |
| return {"error": "Prompt is required"}, 400 |
|
|
| |
| state = await engine.async_generate( |
| prompt, |
| sampling_params={ |
| "max_new_tokens": max_new_tokens, |
| "temperature": temperature, |
| }, |
| |
| ) |
|
|
| return {"generated_text": state["text"]} |
| except Exception as e: |
| return {"error": str(e)}, 500 |
|
|
|
|
def start_server(args, timeout=60):
    """Starts the Uvicorn server as a subprocess and waits for it to be ready."""
    base_url = f"http://{args.host}:{args.port}"
    # "fastapi_engine_inference:app" assumes this file is named
    # fastapi_engine_inference.py and is launched from its own directory.
    command = [
        "python",
        "-m",
        "uvicorn",
        "fastapi_engine_inference:app",
        f"--host={args.host}",
        f"--port={args.port}",
    ]

    # stdout/stderr are inherited so uvicorn and engine logs stay visible.
    process = subprocess.Popen(command)

    # Poll the /docs route until the server answers or the timeout elapses.
    start_time = time.perf_counter()
    with requests.Session() as session:
        while time.perf_counter() - start_time < timeout:
            try:
                response = session.get(f"{base_url}/docs", timeout=5)
                if response.status_code == 200:
                    print(f"Server {base_url} is ready (responded on /docs)")
                    return process
            except requests.ConnectionError:
                # Server is not accepting connections yet; keep waiting.
                pass
            except requests.Timeout:
                print(f"Health check to {base_url}/docs timed out, retrying...")
            except requests.RequestException as e:
                print(f"Health check request error: {e}, retrying...")
            time.sleep(1)

    print("Server failed to start within timeout, terminating process...")
    terminate_process(process)
    raise TimeoutError(
        f"Server failed to start at {base_url} within the timeout period."
    )
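
# start_server() hands ownership of the subprocess to the caller, which must
# eventually call terminate_process() on it (see the __main__ block below).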
|
|
|
|
def send_requests(server_url, prompts, max_new_tokens, temperature):
    """Sends generation requests to the running server for a list of prompts."""
    for i, prompt in enumerate(prompts):
        print(f"\n[{i + 1}/{len(prompts)}] Sending prompt: '{prompt}'")
        payload = {
            "prompt": prompt,
            "max_new_tokens": max_new_tokens,
            "temperature": temperature,
        }

        try:
            response = requests.post(
                f"{server_url}/generate", json=payload, timeout=60
            )
            # Surface HTTP errors before indexing into the response body.
            response.raise_for_status()
            result = response.json()
            print(f"Prompt: {prompt}\nResponse: {result['generated_text']}")
        except requests.exceptions.Timeout:
            print(f"  Error: Request timed out for prompt '{prompt}'")
        except requests.exceptions.RequestException as e:
            print(f"  Error sending request for prompt '{prompt}': {e}")
|
|
|
|
| if __name__ == "__main__": |
| """Main entry point for the script.""" |
|
|
| import argparse |
|
|
| parser = argparse.ArgumentParser() |
| parser.add_argument("--host", type=str, default="127.0.0.1") |
| parser.add_argument("--port", type=int, default=8000) |
| parser.add_argument("--model-path", type=str, default="Qwen/Qwen2.5-0.5B-Instruct") |
| parser.add_argument("--tp_size", type=int, default=1) |
| parser.add_argument( |
| "--startup-timeout", |
| type=int, |
| default=60, |
| help="Time in seconds to wait for the server to be ready (default: %(default)s)", |
| ) |
| args = parser.parse_args() |
|
|
| |
| os.environ["MODEL_PATH"] = args.model_path |
| os.environ["TP_SIZE"] = str(args.tp_size) |
|
|
| |
| process = start_server(args, timeout=args.startup_timeout) |
|
|
| |
| prompts = [ |
| "Hello, my name is", |
| "The president of the United States is", |
| "The capital of France is", |
| "The future of AI is", |
| ] |
| max_new_tokens = 64 |
| temperature = 0.1 |
|
|
| |
| server_url = f"http://{args.host}:{args.port}" |
|
|
| |
| send_requests(server_url, prompts, max_new_tokens, temperature) |
|
|
| |
| terminate_process(process) |
|
|