import asyncio
import os
import time
from typing import AsyncGenerator, List, Optional

import httpx
import orjson
import uvicorn
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import ORJSONResponse, StreamingResponse
from pydantic import BaseModel, Field


app = FastAPI(
    title="Qwen3 API",
    description="Streaming API for Qwen3-0.6B model",
    version="2.0.0",
    default_response_class=ORJSONResponse
)

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


BASE_URL = "http://localhost:8080/v1"

# Shared async client for the OpenAI-compatible backend; created at startup.
http_client: Optional[httpx.AsyncClient] = None


@app.on_event("startup")
async def startup():
    global http_client
    http_client = httpx.AsyncClient(
        base_url=BASE_URL,
        timeout=httpx.Timeout(300.0, connect=10.0),
        limits=httpx.Limits(max_keepalive_connections=10, max_connections=20),
        http2=True  # requires the optional 'h2' package (httpx[http2])
    )


@app.on_event("shutdown")
async def shutdown():
    global http_client
    if http_client:
        await http_client.aclose()
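

# --- Request / response models ---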


class Message(BaseModel):
    role: str
    content: str

    class Config:
        extra = "ignore"


class ChatRequest(BaseModel):
    messages: List[Message]
    temperature: float = Field(default=0.6, ge=0.0, le=2.0)
    top_p: float = Field(default=0.95, ge=0.0, le=1.0)
    max_tokens: int = Field(default=4096, ge=1, le=32768)
    stream: bool = Field(default=True)

    class Config:
        extra = "ignore"


class SimpleChatRequest(BaseModel):
    prompt: str
    temperature: float = Field(default=0.6, ge=0.0, le=2.0)
    top_p: float = Field(default=0.95, ge=0.0, le=1.0)
    max_tokens: int = Field(default=4096, ge=1, le=32768)
    stream: bool = Field(default=True)

    class Config:
        extra = "ignore"
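

# --- Parser for <think> ... </think> segments in the model stream ---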


class ParserState:
    """Accumulates answer/thought text and thinking time across streamed chunks."""

    __slots__ = ['answer', 'thought', 'in_think', 'start_time', 'total_think_time', 'pending']

    def __init__(self):
        self.answer = []
        self.thought = []
        self.in_think = False
        self.start_time = 0.0
        self.total_think_time = 0.0
        self.pending = ''  # partial '<think>'/'</think>' tag held over from the previous chunk

    def get_answer(self) -> str:
        return ''.join(self.answer)

    def get_thought(self) -> str:
        return ''.join(self.thought)


def parse_chunk(content: str, state: ParserState) -> float:
    """Route a streamed chunk into answer/thought buffers; return elapsed think time."""
    buffer = state.pending + content
    state.pending = ''

    while buffer:
        if not state.in_think:
            idx = buffer.find('<think>')
            if idx != -1:
                if idx > 0:
                    state.answer.append(buffer[:idx])
                state.in_think = True
                state.start_time = time.perf_counter()
                buffer = buffer[idx + 7:]
            else:
                # A '<think>' tag may be split across chunks: hold back a partial prefix.
                for i in range(min(6, len(buffer)), 0, -1):
                    if '<think>'[:i] == buffer[-i:]:
                        state.answer.append(buffer[:-i])
                        state.pending = buffer[-i:]
                        return 0.0
                state.answer.append(buffer)
                return 0.0
        else:
            idx = buffer.find('</think>')
            if idx != -1:
                if idx > 0:
                    state.thought.append(buffer[:idx])
                state.total_think_time += time.perf_counter() - state.start_time
                state.in_think = False
                buffer = buffer[idx + 8:]
            else:
                # A '</think>' tag may be split across chunks: hold back a partial prefix.
                for i in range(min(7, len(buffer)), 0, -1):
                    if '</think>'[:i] == buffer[-i:]:
                        state.thought.append(buffer[:-i])
                        state.pending = buffer[-i:]
                        return time.perf_counter() - state.start_time
                state.thought.append(buffer)
                return time.perf_counter() - state.start_time

    return time.perf_counter() - state.start_time if state.in_think else 0.0
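

# Example of how the parser handles a tag split across chunks (illustrative):
#   state = ParserState()
#   parse_chunk("Hi <thi", state)                    # partial tag held in state.pending
#   parse_chunk("nk>pondering</think>done", state)
#   state.get_answer()   -> "Hi done"
#   state.get_thought()  -> "pondering"


# --- Backend streaming helpers ---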


async def stream_from_backend(messages: list, temperature: float, top_p: float, max_tokens: int) -> AsyncGenerator[str, None]:
    """Proxy a streaming /chat/completions request to the backend and yield content deltas."""
    payload = {
        "model": "",
        "messages": messages,
        "temperature": temperature,
        "top_p": top_p,
        "max_tokens": max_tokens,
        "stream": True
    }

    async with http_client.stream(
        "POST",
        "/chat/completions",
        json=payload,
        headers={"Accept": "text/event-stream"}
    ) as response:
        async for line in response.aiter_lines():
            if line.startswith("data: "):
                data = line[6:]
                if data == "[DONE]":
                    break
                try:
                    chunk = orjson.loads(data)
                    if chunk.get("choices") and chunk["choices"][0].get("delta", {}).get("content"):
                        yield chunk["choices"][0]["delta"]["content"]
                except orjson.JSONDecodeError:
                    continue


async def generate_stream_fast(request: ChatRequest) -> AsyncGenerator[bytes, None]:
    """Re-emit the backend stream as OpenAI-style SSE chunks annotated with thinking progress."""
    messages = [{"role": m.role, "content": m.content} for m in request.messages]
    state = ParserState()
    chunk_id = f"chatcmpl-{int(time.time() * 1000)}"
    created = int(time.time())

    try:
        async for content in stream_from_backend(
            messages, request.temperature, request.top_p, request.max_tokens
        ):
            elapsed = parse_chunk(content, state)

            sse_chunk = {
                "id": chunk_id,
                "object": "chat.completion.chunk",
                "created": created,
                "model": "qwen3-0.6b",
                "choices": [{
                    "index": 0,
                    "delta": {"content": content},
                    "finish_reason": None
                }],
                "thinking": {
                    "in_progress": state.in_think,
                    "elapsed": elapsed if state.in_think else state.total_think_time
                }
            }
            yield b"data: " + orjson.dumps(sse_chunk) + b"\n\n"

        final_chunk = {
            "id": chunk_id,
            "object": "chat.completion.chunk",
            "created": created,
            "model": "qwen3-0.6b",
            "choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}],
            "thinking": {
                "in_progress": False,
                "total_think_time": state.total_think_time,
                "thought_content": state.get_thought(),
                "answer_content": state.get_answer()
            }
        }
        yield b"data: " + orjson.dumps(final_chunk) + b"\n\n"
        yield b"data: [DONE]\n\n"

    except Exception as e:
        yield b"data: " + orjson.dumps({"error": {"message": str(e)}}) + b"\n\n"


async def generate_complete_fast(request: ChatRequest) -> dict:
    """Collect the full backend stream and return a single OpenAI-style completion."""
    messages = [{"role": m.role, "content": m.content} for m in request.messages]
    state = ParserState()
    response_parts = []

    try:
        async for content in stream_from_backend(
            messages, request.temperature, request.top_p, request.max_tokens
        ):
            response_parts.append(content)
            parse_chunk(content, state)

        full_response = ''.join(response_parts)

        return {
            "id": f"chatcmpl-{int(time.time() * 1000)}",
            "object": "chat.completion",
            "created": int(time.time()),
            "model": "qwen3-0.6b",
            "choices": [{
                "index": 0,
                "message": {
                    "role": "assistant",
                    "content": full_response,
                    "thinking": {
                        "thought_content": state.get_thought(),
                        "answer_content": state.get_answer(),
                        "total_think_time": state.total_think_time
                    }
                },
                "finish_reason": "stop"
            }]
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
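

# --- HTTP endpoints ---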
@app.get("/")
|
|
|
async def root():
|
|
|
return {"status": "ok", "message": "Qwen3 API is running"}
|
|
|
|
|
|
|
|
|
@app.get("/health")
|
|
|
async def health():
|
|
|
try:
|
|
|
response = await http_client.get("/models")
|
|
|
return {"status": "healthy" if response.status_code == 200 else "unhealthy"}
|
|
|
except Exception as e:
|
|
|
return {"status": "unhealthy", "error": str(e)}
|
|
|
|
|
|
|
|
|
@app.get("/v1/models")
|
|
|
async def list_models():
|
|
|
return {
|
|
|
"object": "list",
|
|
|
"data": [{
|
|
|
"id": "qwen3-0.6b",
|
|
|
"object": "model",
|
|
|
"created": int(time.time()),
|
|
|
"owned_by": "local"
|
|
|
}]
|
|
|
}
|
|
|
|
|
|
|
|
|
@app.post("/v1/chat/completions")
|
|
|
async def chat_completions(request: ChatRequest):
|
|
|
if request.stream:
|
|
|
return StreamingResponse(
|
|
|
generate_stream_fast(request),
|
|
|
media_type="text/event-stream",
|
|
|
headers={
|
|
|
"Cache-Control": "no-cache",
|
|
|
"Connection": "keep-alive",
|
|
|
"X-Accel-Buffering": "no",
|
|
|
"Transfer-Encoding": "chunked"
|
|
|
}
|
|
|
)
|
|
|
return await generate_complete_fast(request)
|
|
|
|
|
|
|
|
|
@app.post("/chat")
|
|
|
async def simple_chat(request: SimpleChatRequest):
|
|
|
chat_request = ChatRequest(
|
|
|
messages=[Message(role="user", content=request.prompt)],
|
|
|
temperature=request.temperature,
|
|
|
top_p=request.top_p,
|
|
|
max_tokens=request.max_tokens,
|
|
|
stream=request.stream
|
|
|
)
|
|
|
|
|
|
if request.stream:
|
|
|
return StreamingResponse(
|
|
|
generate_stream_fast(chat_request),
|
|
|
media_type="text/event-stream",
|
|
|
headers={
|
|
|
"Cache-Control": "no-cache",
|
|
|
"Connection": "keep-alive",
|
|
|
"X-Accel-Buffering": "no"
|
|
|
}
|
|
|
)
|
|
|
return await generate_complete_fast(chat_request)
|
|
|
|
|
|
|
|
|


async def raw_stream_fast(request: ChatRequest) -> AsyncGenerator[bytes, None]:
    messages = [{"role": m.role, "content": m.content} for m in request.messages]

    try:
        async for content in stream_from_backend(
            messages, request.temperature, request.top_p, request.max_tokens
        ):
            yield content.encode()
    except Exception as e:
        yield f"\n\nError: {str(e)}".encode()


@app.post("/chat/raw")
async def raw_chat(request: SimpleChatRequest):
    chat_request = ChatRequest(
        messages=[Message(role="user", content=request.prompt)],
        temperature=request.temperature,
        top_p=request.top_p,
        max_tokens=request.max_tokens,
        stream=True
    )

    return StreamingResponse(
        raw_stream_fast(chat_request),
        media_type="text/plain",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            "X-Accel-Buffering": "no"
        }
    )
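

# Example client call for the raw endpoint (assumes the server is running locally on port 7860):
#   curl -N http://localhost:7860/chat/raw \
#        -H "Content-Type: application/json" \
#        -d '{"prompt": "Hello"}'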
@app.post("/fast")
|
|
|
async def fast_chat(prompt: str = "", max_tokens: int = 512):
|
|
|
messages = [{"role": "user", "content": prompt}]
|
|
|
response_parts = []
|
|
|
|
|
|
async for content in stream_from_backend(messages, 0.6, 0.95, max_tokens):
|
|
|
response_parts.append(content)
|
|
|
|
|
|
return {"response": ''.join(response_parts)}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
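

# --- Mini-server coordination (slot accounting for a fronting main server) ---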


MAX_CONCURRENT_REQUESTS = int(os.environ.get("MAX_CONCURRENT_REQUESTS", "1"))

# In-process slot counter; adequate here because the app runs as a single worker.
current_requests = 0

MINI_SERVER_ID = os.environ.get("MINI_SERVER_ID", "mini-1")


class MiniStatus(BaseModel):
    server_id: str
    max_concurrent: int
    current_requests: int
    status: str
@app.get("/status")
|
|
|
async def mini_status():
|
|
|
"""
|
|
|
Used by the main server to know if this mini is idle/busy.
|
|
|
"""
|
|
|
status = "busy" if current_requests >= MAX_CONCURRENT_REQUESTS else "idle"
|
|
|
return MiniStatus(
|
|
|
server_id=MINI_SERVER_ID,
|
|
|
max_concurrent=MAX_CONCURRENT_REQUESTS,
|
|
|
current_requests=current_requests,
|
|
|
status=status,
|
|
|
)
|
|
|
|
|
|
|
|
|
@app.post("/reserve")
|
|
|
async def reserve_slot():
|
|
|
"""
|
|
|
Called by the main server BEFORE it forwards a chat request.
|
|
|
If this mini is full, returns 429 so main server can try another mini.
|
|
|
"""
|
|
|
global current_requests
|
|
|
if current_requests >= MAX_CONCURRENT_REQUESTS:
|
|
|
raise HTTPException(status_code=429, detail="Mini server busy")
|
|
|
current_requests += 1
|
|
|
return {
|
|
|
"server_id": MINI_SERVER_ID,
|
|
|
"current_requests": current_requests,
|
|
|
"max_concurrent": MAX_CONCURRENT_REQUESTS,
|
|
|
}
|
|
|
|
|
|
|
|
|
@app.post("/release")
|
|
|
async def release_slot():
|
|
|
"""
|
|
|
Called by the main server after request is finished (stream closed/response sent).
|
|
|
"""
|
|
|
global current_requests
|
|
|
if current_requests > 0:
|
|
|
current_requests -= 1
|
|
|
return {
|
|
|
"server_id": MINI_SERVER_ID,
|
|
|
"current_requests": current_requests,
|
|
|
"max_concurrent": MAX_CONCURRENT_REQUESTS,
|
|
|
}
|
|
|
|
|
|
|
|
|
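

# Illustrative sketch of how a main/router server might use the reserve/release protocol.
# The function name and URLs below are assumptions, not part of this service:
#
#   async def forward_to_mini(mini_url: str, payload: dict) -> dict:
#       async with httpx.AsyncClient(base_url=mini_url) as client:
#           if (await client.post("/reserve")).status_code == 429:
#               raise RuntimeError("mini busy; try the next one")
#           try:
#               resp = await client.post("/v1/chat/completions", json=payload)
#               return resp.json()
#           finally:
#               await client.post("/release")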


if __name__ == "__main__":
    uvicorn.run(
        app,
        host="0.0.0.0",
        port=7860,
        loop="uvloop",      # requires the 'uvloop' package (e.g. uvicorn[standard])
        http="httptools",   # requires the 'httptools' package
        access_log=False,
        workers=1
    )