Antaram committed on
Commit
d1f62f9
·
verified ·
1 Parent(s): d2efe0a

Upload 3 files

Browse files
Files changed (3) hide show
  1. Dockerfile +42 -0
  2. app.py +445 -0
  3. start.sh +19 -0
Dockerfile ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
FROM ubuntu:22.04

# Build dependencies for llama.cpp (OpenBLAS backend) plus the Python runtime.
RUN apt-get update && \
    DEBIAN_FRONTEND=noninteractive apt-get install -y \
    build-essential \
    libssl-dev \
    zlib1g-dev \
    libopenblas-dev \
    libomp-dev \
    cmake \
    pkg-config \
    git \
    python3-pip \
    curl \
    libcurl4-openssl-dev \
    wget && \
    rm -rf /var/lib/apt/lists/*

# All Python deps in a single layer (httpx[http2] was previously installed
# with bare `pip` in its own layer).  uvloop is added so uvicorn's fast
# event loop is actually available to app.py.
RUN pip3 install --upgrade pip && \
    pip3 install openai fastapi uvicorn pydantic orjson httptools uvloop "httpx[http2]"

# Build only the llama-server target, Release mode, OpenBLAS enabled.
RUN git clone https://github.com/ggerganov/llama.cpp && \
    cd llama.cpp && \
    cmake -B build -S . \
    -DLLAMA_BUILD_SERVER=ON \
    -DGGML_BLAS=ON \
    -DGGML_BLAS_VENDOR=OpenBLAS \
    -DCMAKE_BUILD_TYPE=Release && \
    cmake --build build --config Release --target llama-server -j $(nproc)

# Bake the quantized model into the image.
RUN mkdir -p /models && \
    wget -O /models/model.gguf https://huggingface.co/unsloth/Qwen3-0.6B-GGUF/resolve/main/Qwen3-0.6B-UD-Q8_K_XL.gguf

COPY app.py /app.py
COPY start.sh /start.sh
RUN chmod +x /start.sh

EXPOSE 7860

CMD ["/start.sh"]
app.py ADDED
@@ -0,0 +1,445 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import time
2
+ import orjson
3
+ import asyncio
4
+ from typing import List, AsyncGenerator
5
+ from fastapi import FastAPI, HTTPException
6
+ import os
7
+ from fastapi.responses import StreamingResponse, ORJSONResponse
8
+ from fastapi.middleware.cors import CORSMiddleware
9
+ from pydantic import BaseModel, Field
10
+ import httpx
11
+ import uvicorn
12
+
13
# Use orjson for faster JSON serialization
app = FastAPI(
    title="Qwen3 API",
    description="Streaming API for Qwen3-0.6B model",
    version="2.0.0",
    default_response_class=ORJSONResponse
)

# Open CORS: any origin/method/header is accepted.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Async HTTP client with connection pooling
# Points at the local llama.cpp server started by start.sh (OpenAI-compatible API).
BASE_URL = "http://localhost:8080/v1"
# Created on startup, closed on shutdown; None until the app has started.
http_client: httpx.AsyncClient = None
32
+
33
+
34
+ @app.on_event("startup")
35
+ async def startup():
36
+ global http_client
37
+ http_client = httpx.AsyncClient(
38
+ base_url=BASE_URL,
39
+ timeout=httpx.Timeout(300.0, connect=10.0),
40
+ limits=httpx.Limits(max_keepalive_connections=10, max_connections=20),
41
+ http2=True
42
+ )
43
+
44
+
45
+ @app.on_event("shutdown")
46
+ async def shutdown():
47
+ global http_client
48
+ if http_client:
49
+ await http_client.aclose()
50
+
51
+
52
+ # ===== Models =====
53
+
54
class Message(BaseModel):
    """Single chat message in OpenAI format (role + content)."""

    role: str
    content: str

    class Config:
        # Drop unknown fields instead of rejecting the request.
        extra = "ignore"
60
+
61
+
62
class ChatRequest(BaseModel):
    """OpenAI-style chat request with bounded sampling parameters."""

    messages: List[Message]
    temperature: float = Field(default=0.6, ge=0.0, le=2.0)
    top_p: float = Field(default=0.95, ge=0.0, le=1.0)
    max_tokens: int = Field(default=4096, ge=1, le=32768)
    stream: bool = Field(default=True)  # SSE streaming by default

    class Config:
        # Drop unknown fields instead of rejecting the request.
        extra = "ignore"
71
+
72
+
73
class SimpleChatRequest(BaseModel):
    """Convenience request: a single prompt string instead of a message list."""

    prompt: str
    temperature: float = Field(default=0.6, ge=0.0, le=2.0)
    top_p: float = Field(default=0.95, ge=0.0, le=1.0)
    max_tokens: int = Field(default=4096, ge=1, le=32768)
    stream: bool = Field(default=True)  # SSE streaming by default

    class Config:
        # Drop unknown fields instead of rejecting the request.
        extra = "ignore"
82
+
83
+
84
# ===== Optimized Think Tag Parser =====
# (The dead module-level `__slots_parser__` list that used to live here has
# been removed; ParserState declares its own __slots__.)

class ParserState:
    """Incremental state for splitting a streamed response into the visible
    answer and the model's <think>...</think> reasoning.

    `buffer` holds a partial tag suffix (e.g. "<thi") seen at the end of a
    chunk so that tags split across chunk boundaries are still detected.
    """

    __slots__ = ['answer', 'thought', 'in_think', 'start_time', 'total_think_time', 'buffer']

    def __init__(self):
        self.answer = []             # answer text fragments (outside think tags)
        self.thought = []            # reasoning text fragments (inside think tags)
        self.in_think = False        # currently inside a <think> section?
        self.start_time = 0.0        # perf_counter() when the current think began
        self.total_think_time = 0.0  # accumulated seconds spent thinking
        self.buffer = ''             # held-back partial tag from the last chunk

    def get_answer(self) -> str:
        # Flush a pending partial that can no longer complete into a tag
        # (only meaningful once the stream has ended).
        if not self.in_think and self.buffer:
            return ''.join(self.answer) + self.buffer
        return ''.join(self.answer)

    def get_thought(self) -> str:
        if self.in_think and self.buffer:
            return ''.join(self.thought) + self.buffer
        return ''.join(self.thought)


def parse_chunk(content: str, state: ParserState) -> float:
    """Consume one streamed chunk, routing text into answer/thought lists.

    Returns the elapsed seconds of the in-progress think section, or 0.0
    when not currently thinking.

    Fix vs. original: a partial tag at the end of a chunk ("<thi") used to
    be silently dropped, so tags split across chunks were never detected
    and those characters were lost; the partial is now carried over in
    state.buffer and re-parsed together with the next chunk.
    """
    # Prepend any partial tag held back from the previous chunk.
    buffer = state.buffer + content
    state.buffer = ''

    while buffer:
        if not state.in_think:
            idx = buffer.find('<think>')
            if idx != -1:
                if idx > 0:
                    state.answer.append(buffer[:idx])
                state.in_think = True
                state.start_time = time.perf_counter()
                buffer = buffer[idx + 7:]  # len('<think>') == 7
            else:
                # Hold back a trailing prefix of '<think>' for the next chunk.
                for i in range(min(6, len(buffer)), 0, -1):
                    if buffer.endswith('<think>'[:i]):
                        if len(buffer) > i:
                            state.answer.append(buffer[:-i])
                        state.buffer = buffer[-i:]
                        return 0.0
                state.answer.append(buffer)
                return 0.0
        else:
            idx = buffer.find('</think>')
            if idx != -1:
                if idx > 0:
                    state.thought.append(buffer[:idx])
                state.total_think_time += time.perf_counter() - state.start_time
                state.in_think = False
                buffer = buffer[idx + 8:]  # len('</think>') == 8
            else:
                # Hold back a trailing prefix of '</think>' for the next chunk.
                for i in range(min(7, len(buffer)), 0, -1):
                    if buffer.endswith('</think>'[:i]):
                        if len(buffer) > i:
                            state.thought.append(buffer[:-i])
                        state.buffer = buffer[-i:]
                        return time.perf_counter() - state.start_time
                state.thought.append(buffer)
                return time.perf_counter() - state.start_time

    return time.perf_counter() - state.start_time if state.in_think else 0.0
141
+
142
+
143
+ # ===== Async Streaming Functions =====
144
+
145
async def stream_from_backend(messages: list, temperature: float, top_p: float, max_tokens: int) -> AsyncGenerator[str, None]:
    """Proxy a streaming chat completion to the local llama.cpp server.

    Yields only the incremental text content of each SSE delta; malformed
    JSON lines and empty deltas are skipped silently.
    """
    payload = {
        "model": "",  # blank model name; presumably the backend ignores it -- TODO confirm
        "messages": messages,
        "temperature": temperature,
        "top_p": top_p,
        "max_tokens": max_tokens,
        "stream": True
    }

    async with http_client.stream(
        "POST",
        "/chat/completions",
        json=payload,
        headers={"Accept": "text/event-stream"}
    ) as response:
        # NOTE(review): non-200 backend responses are not checked here; on
        # error this generator simply yields nothing.
        async for line in response.aiter_lines():
            if line.startswith("data: "):
                data = line[6:]
                if data == "[DONE]":
                    break
                try:
                    chunk = orjson.loads(data)
                    if chunk.get("choices") and chunk["choices"][0].get("delta", {}).get("content"):
                        yield chunk["choices"][0]["delta"]["content"]
                except orjson.JSONDecodeError:
                    # Ignore keep-alive/partial lines that are not valid JSON.
                    continue
172
+
173
+
174
async def generate_stream_fast(request: ChatRequest) -> AsyncGenerator[bytes, None]:
    """Stream OpenAI-style SSE chunks annotated with think-tag progress.

    Each chunk carries the raw delta text plus a "thinking" object with the
    in-progress flag and elapsed think time; a final chunk includes the
    separated thought/answer content before the terminating "[DONE]".
    """
    messages = [{"role": m.role, "content": m.content} for m in request.messages]
    state = ParserState()
    # Millisecond timestamp is unique enough for ids in this single process.
    chunk_id = f"chatcmpl-{int(time.time() * 1000)}"
    created = int(time.time())

    try:
        async for content in stream_from_backend(
            messages, request.temperature, request.top_p, request.max_tokens
        ):
            # Track think-tag state; elapsed is the live think duration.
            elapsed = parse_chunk(content, state)

            sse_chunk = {
                "id": chunk_id,
                "object": "chat.completion.chunk",
                "created": created,
                "model": "qwen3-0.6b",
                "choices": [{
                    "index": 0,
                    "delta": {"content": content},
                    "finish_reason": None
                }],
                "thinking": {
                    "in_progress": state.in_think,
                    # Live duration while thinking, otherwise the final total.
                    "elapsed": elapsed if state.in_think else state.total_think_time
                }
            }
            yield b"data: " + orjson.dumps(sse_chunk) + b"\n\n"

        # Final summary chunk with the fully separated thought/answer text.
        final_chunk = {
            "id": chunk_id,
            "object": "chat.completion.chunk",
            "created": created,
            "model": "qwen3-0.6b",
            "choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}],
            "thinking": {
                "in_progress": False,
                "total_think_time": state.total_think_time,
                "thought_content": state.get_thought(),
                "answer_content": state.get_answer()
            }
        }
        yield b"data: " + orjson.dumps(final_chunk) + b"\n\n"
        yield b"data: [DONE]\n\n"

    except Exception as e:
        # Surface the error in-band; the SSE stream may already have started.
        yield b"data: " + orjson.dumps({"error": {"message": str(e)}}) + b"\n\n"
221
+
222
+
223
async def generate_complete_fast(request: ChatRequest) -> dict:
    """Non-streaming completion: gather the whole backend stream, then
    return one OpenAI-style response with separated thought/answer text.

    Raises:
        HTTPException: 500 with the underlying error message on failure.
    """
    messages = [{"role": m.role, "content": m.content} for m in request.messages]
    state = ParserState()
    response_parts = []

    try:
        async for content in stream_from_backend(
            messages, request.temperature, request.top_p, request.max_tokens
        ):
            response_parts.append(content)
            parse_chunk(content, state)

        full_response = ''.join(response_parts)

        return {
            "id": f"chatcmpl-{int(time.time() * 1000)}",
            "object": "chat.completion",
            "created": int(time.time()),
            "model": "qwen3-0.6b",
            "choices": [{
                "index": 0,
                "message": {
                    "role": "assistant",
                    # Raw text including any think tags; the separated
                    # pieces live under "thinking" below.
                    "content": full_response,
                    "thinking": {
                        "thought_content": state.get_thought(),
                        "answer_content": state.get_answer(),
                        "total_think_time": state.total_think_time
                    }
                },
                "finish_reason": "stop"
            }]
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
258
+
259
+
260
+ # ===== Endpoints =====
261
+
262
+ @app.get("/")
263
+ async def root():
264
+ return {"status": "ok", "message": "Qwen3 API is running"}
265
+
266
+
267
+ @app.get("/health")
268
+ async def health():
269
+ try:
270
+ response = await http_client.get("/models")
271
+ return {"status": "healthy" if response.status_code == 200 else "unhealthy"}
272
+ except Exception as e:
273
+ return {"status": "unhealthy", "error": str(e)}
274
+
275
+
276
+ @app.get("/v1/models")
277
+ async def list_models():
278
+ return {
279
+ "object": "list",
280
+ "data": [{
281
+ "id": "qwen3-0.6b",
282
+ "object": "model",
283
+ "created": int(time.time()),
284
+ "owned_by": "local"
285
+ }]
286
+ }
287
+
288
+
289
+ @app.post("/v1/chat/completions")
290
+ async def chat_completions(request: ChatRequest):
291
+ if request.stream:
292
+ return StreamingResponse(
293
+ generate_stream_fast(request),
294
+ media_type="text/event-stream",
295
+ headers={
296
+ "Cache-Control": "no-cache",
297
+ "Connection": "keep-alive",
298
+ "X-Accel-Buffering": "no",
299
+ "Transfer-Encoding": "chunked"
300
+ }
301
+ )
302
+ return await generate_complete_fast(request)
303
+
304
+
305
+ @app.post("/chat")
306
+ async def simple_chat(request: SimpleChatRequest):
307
+ chat_request = ChatRequest(
308
+ messages=[Message(role="user", content=request.prompt)],
309
+ temperature=request.temperature,
310
+ top_p=request.top_p,
311
+ max_tokens=request.max_tokens,
312
+ stream=request.stream
313
+ )
314
+
315
+ if request.stream:
316
+ return StreamingResponse(
317
+ generate_stream_fast(chat_request),
318
+ media_type="text/event-stream",
319
+ headers={
320
+ "Cache-Control": "no-cache",
321
+ "Connection": "keep-alive",
322
+ "X-Accel-Buffering": "no"
323
+ }
324
+ )
325
+ return await generate_complete_fast(chat_request)
326
+
327
+
328
async def raw_stream_fast(request: ChatRequest) -> AsyncGenerator[bytes, None]:
    """Yield the backend's text chunks as raw UTF-8 bytes, no SSE framing."""
    messages = [{"role": m.role, "content": m.content} for m in request.messages]

    try:
        backend = stream_from_backend(
            messages, request.temperature, request.top_p, request.max_tokens
        )
        async for piece in backend:
            yield piece.encode()
    except Exception as e:
        # Errors are appended in-band since the plain-text stream has begun.
        yield f"\n\nError: {str(e)}".encode()
338
+
339
+
340
+ @app.post("/chat/raw")
341
+ async def raw_chat(request: SimpleChatRequest):
342
+ chat_request = ChatRequest(
343
+ messages=[Message(role="user", content=request.prompt)],
344
+ temperature=request.temperature,
345
+ top_p=request.top_p,
346
+ max_tokens=request.max_tokens,
347
+ stream=True
348
+ )
349
+
350
+ return StreamingResponse(
351
+ raw_stream_fast(chat_request),
352
+ media_type="text/plain",
353
+ headers={
354
+ "Cache-Control": "no-cache",
355
+ "Connection": "keep-alive",
356
+ "X-Accel-Buffering": "no"
357
+ }
358
+ )
359
+
360
+
361
+ @app.post("/fast")
362
+ async def fast_chat(prompt: str = "", max_tokens: int = 512):
363
+ messages = [{"role": "user", "content": prompt}]
364
+ response_parts = []
365
+
366
+ async for content in stream_from_backend(messages, 0.6, 0.95, max_tokens):
367
+ response_parts.append(content)
368
+
369
+ return {"response": ''.join(response_parts)}
370
+
371
# ===== Mini-server load tracking & coordination endpoints =====

# How many concurrent requests this mini should handle
MAX_CONCURRENT_REQUESTS = int(os.environ.get("MAX_CONCURRENT_REQUESTS", "1"))

# In-memory tracking per process
# NOTE(review): per-process only -- with multiple uvicorn workers each
# process would keep its own independent counter.
current_requests = 0

# For identification / debugging
MINI_SERVER_ID = os.environ.get("MINI_SERVER_ID", "mini-1")
381
+
382
+
383
class MiniStatus(BaseModel):
    """Snapshot of this mini-server's load, reported to the main server."""

    server_id: str          # MINI_SERVER_ID of this process
    max_concurrent: int     # configured request capacity
    current_requests: int   # requests currently reserved
    status: str             # "idle" or "busy"
388
+
389
+
390
+ @app.get("/status")
391
+ async def mini_status():
392
+ """
393
+ Used by the main server to know if this mini is idle/busy.
394
+ """
395
+ status = "busy" if current_requests >= MAX_CONCURRENT_REQUESTS else "idle"
396
+ return MiniStatus(
397
+ server_id=MINI_SERVER_ID,
398
+ max_concurrent=MAX_CONCURRENT_REQUESTS,
399
+ current_requests=current_requests,
400
+ status=status,
401
+ )
402
+
403
+
404
+ @app.post("/reserve")
405
+ async def reserve_slot():
406
+ """
407
+ Called by the main server BEFORE it forwards a chat request.
408
+ If this mini is full, returns 429 so main server can try another mini.
409
+ """
410
+ global current_requests
411
+ if current_requests >= MAX_CONCURRENT_REQUESTS:
412
+ raise HTTPException(status_code=429, detail="Mini server busy")
413
+ current_requests += 1
414
+ return {
415
+ "server_id": MINI_SERVER_ID,
416
+ "current_requests": current_requests,
417
+ "max_concurrent": MAX_CONCURRENT_REQUESTS,
418
+ }
419
+
420
+
421
+ @app.post("/release")
422
+ async def release_slot():
423
+ """
424
+ Called by the main server after request is finished (stream closed/response sent).
425
+ """
426
+ global current_requests
427
+ if current_requests > 0:
428
+ current_requests -= 1
429
+ return {
430
+ "server_id": MINI_SERVER_ID,
431
+ "current_requests": current_requests,
432
+ "max_concurrent": MAX_CONCURRENT_REQUESTS,
433
+ }
434
+
435
+
436
+ if __name__ == "__main__":
437
+ uvicorn.run(
438
+ app,
439
+ host="0.0.0.0",
440
+ port=7860,
441
+ loop="uvloop",
442
+ http="httptools",
443
+ access_log=False,
444
+ workers=1
445
+ )
start.sh ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/bin/bash
# Launch the llama.cpp backend, wait for it to come up, then run the
# FastAPI proxy in the foreground.
# Fix vs. original: the script had no shebang, so Docker's exec-form
# CMD ["/start.sh"] failed with an exec format error.

# Start the llama.cpp server in the background on the internal port.
cd /llama.cpp/build
./bin/llama-server \
  --host 0.0.0.0 \
  --port 8080 \
  --model /models/model.gguf \
  --ctx-size 32768 \
  --threads 2 &

echo "Waiting for llama.cpp server..."
until curl -s "http://localhost:8080/v1/models" >/dev/null 2>&1; do
  sleep 1
done
echo "llama.cpp server is ready."

# Start FastAPI
# exec replaces this shell so container stop signals reach uvicorn directly.
echo "Starting FastAPI server on port 7860..."
cd /
exec python3 -m uvicorn app:app --host 0.0.0.0 --port 7860