Spaces:
Build error
Build error
import asyncio
import functools
import os
import secrets
import subprocess
import time
from typing import Any, List, Optional

import requests
import torch
from fastapi import Depends, FastAPI, HTTPException, Request
from fastapi.responses import JSONResponse
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from pydantic import BaseModel
# The ASGI application object served by uvicorn at the bottom of this file.
app = FastAPI(
    title="Qwen3-VL SGLang API with Auth",
)
# Configuration
MODEL_ID = "Qwen/Qwen3-VL-8B-Thinking"
SGLANG_PORT = 30000
SGLANG_HOST = "127.0.0.1"
SGLANG_URL = f"http://{SGLANG_HOST}:{SGLANG_PORT}"
# Bearer token that clients must present. Prefer supplying it through the
# API_KEY environment variable so the secret does not live in source control;
# the literal fallback (also used when the variable is empty) only exists to
# keep existing deployments working.
API_KEY = os.environ.get("API_KEY") or "sk-sheikh545466"
security = HTTPBearer()


def verify_token(
    credentials: HTTPAuthorizationCredentials = Depends(security),
) -> str:
    """Validate the bearer token of the current request.

    Intended to be used as a FastAPI dependency. Extracting the
    ``Authorization: Bearer ...`` header is delegated to the ``HTTPBearer``
    dependency; this function only checks the extracted token.

    Args:
        credentials: Credentials extracted by ``security``.

    Returns:
        The presented token, when it matches ``API_KEY``.

    Raises:
        HTTPException: 403 when the token does not match ``API_KEY``.
    """
    presented = credentials.credentials.encode("utf-8")
    expected = API_KEY.encode("utf-8")
    # Constant-time comparison so response timing does not leak how much of
    # the key was guessed correctly. Bytes are compared because
    # compare_digest rejects non-ASCII str arguments.
    if not secrets.compare_digest(presented, expected):
        raise HTTPException(
            status_code=403,
            detail="Invalid or missing API Key",
        )
    return credentials.credentials
# Global process for SGLang server
sglang_process = None


def start_sglang() -> None:
    """Launch the SGLang server subprocess once and block until it is ready.

    If a process has already been started this returns immediately.
    Otherwise the server is spawned for ``MODEL_ID`` and ``/v1/models`` is
    polled every 10 seconds, for at most 60 attempts.

    Raises:
        RuntimeError: If the subprocess exits while starting up, or if the
            server does not answer with HTTP 200 within the retry budget.
    """
    global sglang_process
    if sglang_process is not None:
        return

    print(f"Starting SGLang server for {MODEL_ID}...")
    cmd = [
        "python3", "-m", "sglang.launch_server",
        "--model-path", MODEL_ID,
        "--host", SGLANG_HOST,
        "--port", str(SGLANG_PORT),
        "--chat-template", "qwen2-vl",
        "--trust-remote-code",
    ]
    # Shard the model across all visible GPUs when there is more than one.
    gpu_count = torch.cuda.device_count()
    if gpu_count > 1:
        cmd.extend(["--tp", str(gpu_count)])
    sglang_process = subprocess.Popen(cmd)

    max_retries = 60
    for attempt in range(1, max_retries + 1):
        # Fail fast if the server crashed instead of polling a dead process
        # for the whole retry budget.
        exit_code = sglang_process.poll()
        if exit_code is not None:
            # Forget the dead process so a later call can try again.
            sglang_process = None
            raise RuntimeError(
                f"SGLang server exited during startup with code {exit_code}."
            )
        try:
            # Bounded probe: a half-open socket must not stall the loop.
            response = requests.get(f"{SGLANG_URL}/v1/models", timeout=5)
        except requests.RequestException:
            # Not listening yet; keep waiting.
            pass
        else:
            if response.status_code == 200:
                print("SGLang server is ready!")
                return
        print(f"Waiting for SGLang server... ({attempt}/{max_retries})")
        time.sleep(10)
    raise RuntimeError("SGLang server failed to start within timeout.")
# NOTE(review): SOURCE defined this coroutine without registering it, so the
# backend was never launched. `on_event` is used here; newer FastAPI versions
# recommend a lifespan handler instead - confirm which the project prefers.
@app.on_event("startup")
async def startup_event():
    """Start the SGLang backend in a background thread at application startup.

    ``start_sglang`` blocks while it waits for the server, so it runs in a
    daemon thread and the API can begin accepting requests immediately.
    """
    import threading

    threading.Thread(target=start_sglang, daemon=True).start()
# Request models
class Message(BaseModel):
    # A single chat message; instances are serialised into the "messages"
    # list of the chat-completions payload.
    role: str
    # Declared as Any, so this API does not validate the content's shape.
    content: Any
class ResponseRequest(BaseModel):
    # Body of a plain text-completion request.
    prompt: str
    # Sampling options, passed through to the backend unchanged.
    max_tokens: Optional[int] = 512
    temperature: Optional[float] = 0.7
class MessageRequest(BaseModel):
    # Body of a chat-style request: the conversation so far.
    messages: List[Message]
    # Sampling options, passed through to the backend unchanged.
    max_tokens: Optional[int] = 512
    temperature: Optional[float] = 0.7
# NOTE(review): SOURCE had no route decorator, leaving this handler
# unreachable; "/" is inferred from the function name - confirm.
@app.get("/")
async def root():
    """Return a static banner; this endpoint requires no authentication."""
    return {"message": "Qwen3-VL SGLang API is running. Auth required for /v1/ endpoints."}
# NOTE(review): SOURCE had no route decorator, leaving this handler
# unreachable; "/health" is inferred from the function name - confirm.
@app.get("/health")
async def health():
    """Report whether the SGLang backend answers on ``/v1/models``.

    Returns:
        ``{"status": "healthy", ...}`` when the backend replies with HTTP 200
        within 2 seconds, otherwise ``{"status": "starting", ...}``.
    """
    probe = functools.partial(
        requests.get, f"{SGLANG_URL}/v1/models", timeout=2
    )
    try:
        # requests is synchronous; run it on the default executor so the
        # event loop keeps serving other requests while the probe is pending.
        resp = await asyncio.get_running_loop().run_in_executor(None, probe)
    except requests.RequestException:
        # Backend unreachable or too slow: still starting.
        resp = None
    if resp is not None and resp.status_code == 200:
        return {"status": "healthy", "backend": "sglang"}
    return {"status": "starting", "backend": "sglang"}
# NOTE(review): SOURCE had no route decorator, leaving this handler
# unreachable. The path below is inferred from the handler/model names and
# the "/v1/" hint in the root banner - confirm against existing clients.
@app.post("/v1/responses")
async def generate_response(request: ResponseRequest, token: str = Depends(verify_token)):
    """Run a text completion on the SGLang backend.

    Args:
        request: Prompt and sampling options.
        token: Validated bearer token; only present to enforce auth.

    Returns:
        ``{"response": <text of the first choice>}``.

    Raises:
        HTTPException: 500 when the backend call fails or its reply does not
            have the expected shape.
    """
    payload = {
        "model": MODEL_ID,
        "prompt": request.prompt,
        "max_tokens": request.max_tokens,
        "temperature": request.temperature,
    }
    # Bounded like the proxy endpoint so a stuck backend cannot hold the
    # request open indefinitely.
    call = functools.partial(
        requests.post, f"{SGLANG_URL}/v1/completions", json=payload, timeout=300
    )
    try:
        # requests is synchronous; keep it off the event loop.
        resp = await asyncio.get_running_loop().run_in_executor(None, call)
        resp.raise_for_status()
        data = resp.json()
        return {"response": data["choices"][0]["text"]}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
# NOTE(review): SOURCE had no route decorator, leaving this handler
# unreachable. The path below is inferred from the handler/model names and
# the "/v1/" hint in the root banner - confirm against existing clients.
@app.post("/v1/messages")
async def generate_message(request: MessageRequest, token: str = Depends(verify_token)):
    """Run a chat completion on the SGLang backend.

    Args:
        request: Conversation messages and sampling options.
        token: Validated bearer token; only present to enforce auth.

    Returns:
        ``{"message": <content of the first choice's message>}``.

    Raises:
        HTTPException: 500 when the backend call fails or its reply does not
            have the expected shape.
    """
    payload = {
        "model": MODEL_ID,
        # Prefer pydantic v2's model_dump(); fall back to v1's dict().
        "messages": [
            m.model_dump() if hasattr(m, "model_dump") else m.dict()
            for m in request.messages
        ],
        "max_tokens": request.max_tokens,
        "temperature": request.temperature,
    }
    # Bounded like the proxy endpoint so a stuck backend cannot hold the
    # request open indefinitely.
    call = functools.partial(
        requests.post,
        f"{SGLANG_URL}/v1/chat/completions",
        json=payload,
        timeout=300,
    )
    try:
        # requests is synchronous; keep it off the event loop.
        resp = await asyncio.get_running_loop().run_in_executor(None, call)
        resp.raise_for_status()
        data = resp.json()
        return {"message": data["choices"][0]["message"]["content"]}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
# Proxy other OpenAI compatible requests to SGLang
# NOTE(review): SOURCE had no route decorator, leaving this handler
# unreachable. The path follows from the upstream URL built below; the
# method list is an assumption - confirm.
@app.api_route("/v1/{path:path}", methods=["GET", "POST", "PUT", "DELETE"])
async def proxy_openai(path: str, request: Request, token: str = Depends(verify_token)):
    """Forward an authenticated request to the same ``/v1/`` path on SGLang.

    Args:
        path: Remainder of the URL after ``/v1/``.
        request: Incoming request; its method, headers, query string and body
            are forwarded.
        token: Validated bearer token; only present to enforce auth.

    Returns:
        The backend's decoded JSON body. When the backend answers with an
        error status, the same body is returned with that status code.

    Raises:
        HTTPException: 500 when the backend cannot be reached or its reply is
            not valid JSON.
    """
    url = f"{SGLANG_URL}/v1/{path}"
    # Keep the caller's query string; it would otherwise be dropped.
    if request.url.query:
        url = f"{url}?{request.url.query}"
    method = request.method
    # Drop the inbound Host header and this API's bearer token before
    # forwarding to the backend.
    skipped = {"host", "authorization"}
    headers = {k: v for k, v in request.headers.items() if k.lower() not in skipped}
    body = await request.body()
    call = functools.partial(
        requests.request, method, url, headers=headers, data=body, timeout=300
    )
    try:
        # requests is synchronous; keep it off the event loop.
        resp = await asyncio.get_running_loop().run_in_executor(None, call)
        payload = resp.json()
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
    if not resp.ok:
        # Preserve the backend's error status instead of masking it as 200.
        return JSONResponse(status_code=resp.status_code, content=payload)
    return payload
if __name__ == "__main__":
    # Script entry point: serve the app on all interfaces, port 7860.
    # uvicorn is imported here, so it is only needed when run as a script.
    from uvicorn import run as serve

    serve(app, port=7860, host="0.0.0.0")