File size: 1,911 Bytes
1f5fda3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
import asyncio
import json
import os

import uvicorn
from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import StreamingResponse
from ziprc import ZIPRCModel, ZIPRCConfig, ZIPRCSampler

# --- Configuration ---
# Each setting can be overridden through an environment variable; the defaults
# are the previous hard-coded values. An invalid ZIPRC_PORT raises ValueError
# at import time (fail fast rather than bind to a surprising port).
# NOTE(review): 0.0.0.0 listens on every interface and the chat endpoint in
# this file performs no authentication — consider 127.0.0.1 for local use.
HOST = os.environ.get("ZIPRC_HOST", "0.0.0.0")
PORT = int(os.environ.get("ZIPRC_PORT", "8000"))
MODEL_ID = os.environ.get(
    "ZIPRC_MODEL_ID", "dataopsnick/Qwen3-4B-Instruct-2507-zip-rc"
)

# --- Load Model Once ---
# Built at import time so the request handler can reuse one module-level
# sampler instead of reloading the model per request.
print(f"Loading {MODEL_ID}...")
cfg = ZIPRCConfig(model_name=MODEL_ID)
model = ZIPRCModel(cfg)
sampler = ZIPRCSampler(model)
print("Model loaded. Starting server...")

app = FastAPI(title="ZIP-RC OpenAI Compatible API")

@app.post("/v1/chat/completions")
async def chat_completions(request: Request):
    """
    Standard OpenAI Chat Completion endpoint.
    Streams JSON chunks as Server-Sent Events (SSE).

    Reads ``messages`` (default ``[]``) and the token limit from the JSON
    body. The limit is ``max_tokens``; if that is absent or null it falls back
    to ``max_completion_tokens``, then to 512. Both values are passed to
    ``sampler.openai``. Each chunk the sampler yields is sent as one
    ``data:`` event, and the stream is terminated with ``data: [DONE]``.

    NOTE(review): the response is always streamed, even when the request sets
    ``"stream": false``; clients expecting a single JSON object will not
    parse it.

    Raises:
        HTTPException: 400 if the body is not valid JSON or not a JSON object,
            if ``messages`` is not a list, or if the token limit is not a
            positive integer.
    """
    try:
        data = await request.json()
    except ValueError as err:  # JSONDecodeError or undecodable bytes
        raise HTTPException(
            status_code=400, detail="Request body must be valid JSON."
        ) from err
    if not isinstance(data, dict):
        raise HTTPException(
            status_code=400, detail="Request body must be a JSON object."
        )

    messages = data.get("messages", [])
    if not isinstance(messages, list):
        raise HTTPException(status_code=400, detail="'messages' must be a list.")

    # An explicit null is treated like an absent field. Newer OpenAI clients
    # send max_completion_tokens in place of the deprecated max_tokens.
    max_tokens = data.get("max_tokens")
    if max_tokens is None:
        max_tokens = data.get("max_completion_tokens")
    if max_tokens is None:
        max_tokens = 512
    # bool is a subclass of int, so it must be excluded explicitly.
    if (
        isinstance(max_tokens, bool)
        or not isinstance(max_tokens, int)
        or max_tokens <= 0
    ):
        raise HTTPException(
            status_code=400, detail="'max_tokens' must be a positive integer."
        )

    # 1. Use the sampler's generator
    stream = sampler.openai(messages, max_tokens=max_tokens)

    # 2. Convert to SSE format
    async def sse_generator():
        async for chunk in stream:
            # chunk is an OpenAIObject (dict-like)
            payload = json.dumps(dict(chunk))
            yield f"data: {payload}\n\n"
        yield "data: [DONE]\n\n"

    # An event stream is live output; it should never be served from a cache.
    return StreamingResponse(
        sse_generator(),
        media_type="text/event-stream",
        headers={"Cache-Control": "no-cache"},
    )

if __name__ == "__main__":
    # Use direct Server instantiation to avoid nested-asyncio conflicts in
    # Notebooks: we decide ourselves how server.serve() gets scheduled.
    config = uvicorn.Config(app, host=HOST, port=PORT)
    server = uvicorn.Server(config)

    # Detect whether we are already inside an event loop (e.g. Colab/Jupyter).
    # Only the probe sits in the try, so a RuntimeError raised by the server
    # itself is never mistaken for "no running loop".
    try:
        loop = asyncio.get_running_loop()
    except RuntimeError:
        loop = None

    if loop is None:
        # Standard script execution: block until the server shuts down.
        print(f"Server starting on http://{HOST}:{PORT}")
        asyncio.run(server.serve())
    else:
        # Keep a strong reference to the task: the event loop holds only weak
        # references, so an unreferenced task can be garbage-collected
        # mid-execution.
        server_task = loop.create_task(server.serve())
        print(f"Server started in background task on http://{HOST}:{PORT}")