import os
import requests
import time
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse, HTMLResponse
from llama_cpp import Llama
from pydantic import BaseModel
import uvicorn
from typing import Generator
import threading

# Configuration
MODEL_URL = "https://huggingface.co/unsloth/DeepSeek-R1-Distill-Qwen-1.5B-GGUF/resolve/main/DeepSeek-R1-Distill-Qwen-1.5B-Q4_K_M.gguf"  # Changed to Q4 for faster inference
MODEL_NAME = "DeepSeek-R1-Distill-Qwen-1.5B-Q4_K_M.gguf"
MODEL_DIR = "model"
MODEL_PATH = os.path.join(MODEL_DIR, MODEL_NAME)

# Create model directory if it doesn't exist
os.makedirs(MODEL_DIR, exist_ok=True)

# Download the model if it doesn't exist
if not os.path.exists(MODEL_PATH):
    print(f"Downloading model from {MODEL_URL}...")
    response = requests.get(MODEL_URL, stream=True)
    if response.status_code == 200:
        with open(MODEL_PATH, "wb") as f:
            for chunk in response.iter_content(chunk_size=8192):
                f.write(chunk)
        print("Model downloaded successfully!")
    else:
        raise RuntimeError(f"Failed to download model: HTTP {response.status_code}")
else:
    print("Model already exists. Skipping download.")

# Initialize FastAPI
app = FastAPI(
    title="DeepSeek-R1 OpenAI-Compatible API",
    description="Optimized OpenAI-compatible API with streaming support",
    version="2.0.0"
)

# CORS Configuration
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)

# Global model loader with optimized settings
print("Loading model with optimized settings...")
try:
    llm = Llama(
        model_path=MODEL_PATH,
        n_ctx=1024,       # Reduced context window for faster processing
        n_threads=8,      # Increased threads for better CPU utilization
        n_batch=512,      # Larger batch size for improved throughput
        n_gpu_layers=0,
        use_mlock=True,   # Prevent swapping to disk
        verbose=False
    )
    print("Model loaded with optimized settings!")
except Exception as e:
    raise RuntimeError(f"Failed to load model: {str(e)}")

# Streaming generator
def generate_stream(prompt: str, max_tokens: int, temperature: float, top_p: float) -> Generator[str, None, None]:
    start_time = time.time()
    stream = llm.create_completion(
        prompt=prompt,
        max_tokens=max_tokens,
        temperature=temperature,
        top_p=top_p,
        stop=[""],
        stream=True
    )
    for chunk in stream:
        delta = chunk['choices'][0]['text']
        yield f"data: {delta}\n\n"
        # Early stopping if taking too long
        if time.time() - start_time > 30:  # 30s timeout
            break

# OpenAI-Compatible Request Schema
class ChatCompletionRequest(BaseModel):
    model: str = "DeepSeek-R1-Distill-Qwen-1.5B"
    messages: list[dict]
    max_tokens: int = 256
    temperature: float = 0.7
    top_p: float = 0.9
    stream: bool = False

# Enhanced root endpoint with performance info
@app.get("/", response_class=HTMLResponse)
async def root():
    return f"""
For private use, please duplicate this space:
1. Click your profile picture in the top-right
2. Select "Duplicate Space"
3. Set visibility to Private
curl -N -X POST "{os.environ.get('SPACE_HOST', 'http://localhost:7860')}/v1/chat/completions" \\
-H "Content-Type: application/json" \\
-d '{{
"messages": [{{"role": "user", "content": "Explain quantum computing"}}],
"stream": true,
"max_tokens": 150
}}'
"""
# Async endpoint handler
@app.post("/v1/chat/completions")
async def chat_completion(request: ChatCompletionRequest):
    try:
        prompt = "\n".join([f"{msg['role']}: {msg['content']}" for msg in request.messages])
        prompt += "\nassistant:"
        if request.stream:
            return StreamingResponse(
                generate_stream(
                    prompt=prompt,
                    max_tokens=request.max_tokens,
                    temperature=request.temperature,
                    top_p=request.top_p
                ),
                media_type="text/event-stream"
            )
        # Non-streaming response
        start_time = time.time()
        response = llm(
            prompt=prompt,
            max_tokens=request.max_tokens,
            temperature=request.temperature,
            top_p=request.top_p,
            stop=[""]
        )
        return {
            "id": f"chatcmpl-{int(time.time())}",
            "object": "chat.completion",
            "created": int(time.time()),
            "model": request.model,
            "choices": [{
                "index": 0,
                "message": {
                    "role": "assistant",
                    "content": response['choices'][0]['text'].strip()
                },
                "finish_reason": "stop"
            }],
            "usage": {
                # NOTE: character counts used as a rough stand-in for true token counts
                "prompt_tokens": len(prompt),
                "completion_tokens": len(response['choices'][0]['text']),
                "total_tokens": len(prompt) + len(response['choices'][0]['text'])
            }
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
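
# Example (not part of the server): one way a client could consume the SSE
# stream emitted by the endpoint above. This is a hedged sketch meant to be run
# in a separate process; it assumes the server is reachable at
# http://localhost:7860 and that each event is a raw "data: <text>" line, as
# produced by generate_stream().
#
#   import requests
#
#   with requests.post(
#       "http://localhost:7860/v1/chat/completions",
#       json={"messages": [{"role": "user", "content": "Hello"}], "stream": True},
#       stream=True,
#   ) as resp:
#       for line in resp.iter_lines(decode_unicode=True):
#           if line and line.startswith("data: "):
#               print(line[len("data: "):], end="", flush=True)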
@app.get("/health")
async def health_check():
return {
"status": "healthy",
"model_loaded": True,
"performance_settings": {
"n_threads": llm.params.n_threads,
"n_ctx": llm.params.n_ctx,
"n_batch": llm.params.n_batch
}
}
if __name__ == "__main__":
    uvicorn.run(
        app,
        host="0.0.0.0",
        port=7860,
        timeout_keep_alive=300  # Keep alive for streaming connections
    )
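
# Example (not part of the server): a minimal non-streaming client, assuming the
# host/port defaults configured above. The response shape mirrors the dict built
# in chat_completion().
#
#   import requests
#
#   resp = requests.post(
#       "http://localhost:7860/v1/chat/completions",
#       json={
#           "messages": [{"role": "user", "content": "Explain quantum computing"}],
#           "max_tokens": 150,
#       },
#   )
#   print(resp.json()["choices"][0]["message"]["content"])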