# deepseekapi1 / app.py — Hugging Face Space by ahmadalfakeh (commit 13d5145, verified)
# app.py
import hmac
import json
import os
import re
import time
import uuid
from typing import Optional, Any, List, Literal

from fastapi import FastAPI, Header, HTTPException
from fastapi.responses import StreamingResponse, JSONResponse
from pydantic import BaseModel
from huggingface_hub import hf_hub_download
from llama_cpp import Llama
# Human-readable title shown in FastAPI's generated docs.
APP_TITLE = "HF OpenAI-Compatible API (OpenCode+n8n) fast SSE + plain stream"
# -----------------------
# GGUF model
# -----------------------
# Hugging Face repo and quantized GGUF file downloaded at startup (both overridable via env).
HF_REPO_ID = os.environ.get("HF_REPO_ID", "bartowski/DeepSeek-R1-Distill-Qwen-1.5B-GGUF")
HF_FILENAME = os.environ.get("HF_FILENAME", "DeepSeek-R1-Distill-Qwen-1.5B-Q4_K_M.gguf")
# DeepSeek chat tokens — special role markers inserted when building prompts.
BOS = "<|begin▁of▁sentence|>"
USR = "<|User|>"
AST = "<|Assistant|>"
# Model id reported to clients in /v1/models and echoed in every response payload.
PRIMARY_MODEL_ID = os.environ.get("PRIMARY_MODEL_ID", "deepseek-r1-distill-qwen-1.5b-q4_k_m")
# -----------------------
# Auth
# -----------------------
# Shared bearer token; refuse to start without it so the API is never
# accidentally exposed unauthenticated.
API_KEY = os.environ.get("API_KEY", "").strip()
if not API_KEY:
    raise RuntimeError("Missing API_KEY secret (Space Settings -> Secrets).")
def require_auth(auth: Optional[str]) -> None:
    """Validate an ``Authorization: Bearer <key>`` header against API_KEY.

    Raises:
        HTTPException: 401 when the header is missing or not a Bearer scheme,
            403 when the presented token does not match API_KEY.
    """
    if not auth or not auth.startswith("Bearer "):
        raise HTTPException(status_code=401, detail="Missing Authorization: Bearer <key>")
    token = auth.removeprefix("Bearer ").strip()
    # Constant-time comparison: a plain `!=` short-circuits on the first
    # differing byte and leaks key prefixes through response timing.
    if not hmac.compare_digest(token, API_KEY):
        raise HTTPException(status_code=403, detail="Forbidden")
def oai_error(message: str, code: str = "internal_error", status: int = 500):
    """Build an OpenAI-style error JSON response with the given HTTP status."""
    payload = {"error": {"message": message, "type": "server_error", "code": code}}
    return JSONResponse(status_code=status, content=payload)
# -----------------------
# Performance tuning (HF free CPU defaults)
# -----------------------
# Recommended for OpenCode+n8n:
# N_CTX=2048 (or 3072 if you can afford more latency)
# N_BATCH=512 or 1024 (try 1024 first)
N_THREADS = int(os.environ.get("N_THREADS", "2"))  # llama.cpp worker threads
N_CTX = int(os.environ.get("N_CTX", "3072"))  # context window size (tokens)
N_BATCH = int(os.environ.get("N_BATCH", "1024"))  # prompt-eval batch size
MAX_TOKENS_DEFAULT = int(os.environ.get("MAX_TOKENS_DEFAULT", "256"))  # completion cap when a client omits max_tokens
CTX_MARGIN = int(os.environ.get("CTX_MARGIN", "96"))  # headroom kept free in the context window
# SSE chunking knobs (make OpenCode feel fast)
SSE_FLUSH_CHARS = int(os.environ.get("SSE_FLUSH_CHARS", "48")) # flush after ~48 chars buffered
SSE_FLUSH_SEC = float(os.environ.get("SSE_FLUSH_SEC", "0.12")) # or flush after 120ms
# History trimming knobs (keeps prompts smaller => faster)
KEEP_LAST_TURNS = int(os.environ.get("KEEP_LAST_TURNS", "8")) # keep last 8 non-system messages
# Keep the server default system short for speed; clients can send their own.
DEFAULT_SYSTEM = (
    "You are a helpful programming assistant. "
    "Answer directly and concisely. "
    "No <think>. No reasoning."
)
# Matches complete <think>...</think> reasoning blocks; DOTALL so they may span newlines.
THINK_BLOCK_RE = re.compile(r"<think>.*?</think>", re.DOTALL)
def strip_think(text: str) -> str:
    """Remove DeepSeek ``<think>...</think>`` reasoning from *text*.

    Complete blocks are deleted; if an unterminated ``<think>`` remains,
    everything after it is treated as reasoning and dropped too.
    """
    cleaned = re.sub(r"(?s)<think>.*?</think>", "", text)
    head, _, _ = cleaned.partition("<think>")
    return head.strip()
# -----------------------
# App + Model
# -----------------------
app = FastAPI(title=APP_TITLE)
# Lazily-loaded llama.cpp model plus bookkeeping; populated by ensure_model_loaded().
llm: Optional[Llama] = None
MODEL_PATH: Optional[str] = None  # local path of the downloaded GGUF file
LOAD_ERROR: Optional[str] = None  # last startup/load failure message, surfaced via /health
def ensure_model_loaded() -> None:
    """Download the GGUF file (if needed) and initialise the global ``llm``.

    Idempotent: returns immediately once ``llm`` is set. After loading, runs a
    tiny warm-up generation so the first real request doesn't pay the full
    cold-start latency. Raises on download/load failure; the startup hook
    records that failure in LOAD_ERROR.
    """
    global llm, MODEL_PATH, LOAD_ERROR
    if llm is not None:
        return
    LOAD_ERROR = None
    MODEL_PATH = hf_hub_download(
        repo_id=HF_REPO_ID,
        filename=HF_FILENAME,
        local_dir="/tmp/models",
    )
    t0 = time.time()
    llm = Llama(
        model_path=MODEL_PATH,
        n_threads=N_THREADS,
        n_ctx=N_CTX,
        n_batch=N_BATCH,
        use_mmap=True,   # map the file instead of reading it fully into RAM
        use_mlock=False,  # don't pin pages; the free CPU tier has little memory
    )
    print(f"Model loaded in {time.time() - t0:.1f}s: {MODEL_PATH}")
    # Warm-up (reduces first token delay)
    try:
        _ = llm(f"{BOS}Warmup{USR}hi{AST}", max_tokens=16, temperature=0.0, top_p=1.0)
        print("Warm-up completed")
    except Exception as e:
        # Best-effort: a failed warm-up only costs latency, never availability.
        print(f"Warm-up failed (ignored): {e}")
@app.on_event("startup")
def startup_event():
    """Preload the model at startup; record (never raise) any failure so the
    service still comes up and /health can report what went wrong."""
    global LOAD_ERROR
    try:
        ensure_model_loaded()
    except Exception as exc:
        LOAD_ERROR = str(exc)
        print(f"Startup preload failed: {exc}")
# -----------------------
# Token counting + clamping (prevents context overflow crashes)
# -----------------------
def prompt_token_count(prompt: str) -> int:
    """Return the number of tokens *prompt* occupies under the loaded model's tokenizer."""
    encoded = prompt.encode("utf-8")
    return len(llm.tokenize(encoded))
def clamp_max_tokens(prompt: str, requested: int) -> int:
    """Clamp *requested* completion tokens so prompt + output fits in N_CTX.

    Keeps CTX_MARGIN tokens of headroom and always returns at least 1 so
    callers never request a zero-token completion.
    """
    used = prompt_token_count(prompt)
    room = N_CTX - used - CTX_MARGIN
    if room < 0:
        room = 0
    return max(1, min(int(requested), int(room)))
# -----------------------
# Lenient schemas (accept extra fields OpenCode/LangChain send)
# -----------------------
class LenientModel(BaseModel):
    """Base request model that accepts (and ignores) unknown client fields."""
    model_config = {"extra": "allow"}
class ChatMessage(LenientModel):
    """One OpenAI-style chat message."""
    role: Literal["system", "user", "assistant", "tool"]
    content: Optional[str] = None
    tool_call_id: Optional[str] = None  # present on role="tool" messages
    name: Optional[str] = None  # optional speaker/tool name
class ChatCompletionsReq(LenientModel):
    """Request body for /v1/chat/completions (lenient OpenAI schema)."""
    model: Optional[str] = None  # accepted but ignored; PRIMARY_MODEL_ID always serves
    messages: List[ChatMessage]
    temperature: Optional[float] = 0.2
    top_p: Optional[float] = 0.9
    max_tokens: Optional[int] = None
    max_completion_tokens: Optional[int] = None  # newer OpenAI name; takes precedence over max_tokens
    stream: Optional[bool] = False
    stop: Optional[Any] = None  # accepted but not forwarded to the model
class ResponsesReq(LenientModel):
    """Request body for /v1/responses."""
    model: Optional[str] = None  # accepted but ignored
    input: Any = None  # str, list of role dicts, or Responses-API message items
    temperature: Optional[float] = 0.2
    top_p: Optional[float] = 0.9
    max_output_tokens: Optional[int] = None
    stream: Optional[bool] = False  # accepted; the endpoint only answers non-streaming
class CompletionsReq(LenientModel):
    """Request body for the legacy /v1/completions endpoint."""
    model: Optional[str] = None  # accepted but ignored
    prompt: Any = None  # str or list (list entries are joined with newlines)
    temperature: Optional[float] = 0.2
    top_p: Optional[float] = 0.9
    max_tokens: Optional[int] = None
    stream: Optional[bool] = False
class GenerateStreamReq(LenientModel):
    """Request body for /generate_stream (plain-text token streaming)."""
    prompt: str
    max_new_tokens: int = 200
    temperature: float = 0.2
    top_p: float = 0.9
# -----------------------
# Prompt builders
# -----------------------
def trim_messages_for_speed(messages: List[ChatMessage], keep_last_non_system: int = KEEP_LAST_TURNS) -> List[ChatMessage]:
    """Keep every system message plus only the newest non-system turns.

    Shorter histories mean smaller prompts, which keeps CPU inference fast.
    """
    system_msgs: List[ChatMessage] = []
    conversation: List[ChatMessage] = []
    for msg in messages:
        (system_msgs if msg.role == "system" else conversation).append(msg)
    return system_msgs + conversation[-keep_last_non_system:]
def messages_to_prompt(messages: List[ChatMessage]) -> str:
    """Render chat messages into a single DeepSeek-formatted prompt string.

    System content is collected separately; tool results are presented to the
    model as user turns. The prompt ends with the assistant marker so the
    model continues as the assistant.
    """
    system_parts: List[str] = []
    turns: List[str] = []
    for msg in messages:
        body = msg.content or ""
        if msg.role == "system":
            system_parts.append(body)
        elif msg.role == "user":
            turns.append(f"{USR}{body}\n")
        elif msg.role == "assistant":
            turns.append(f"{AST}{body}\n")
        elif msg.role == "tool":
            turns.append(f"{USR}[Tool result]\n{body}\n")
    system_text = "".join(system_parts)
    if not system_text.strip():
        system_text = DEFAULT_SYSTEM
    else:
        # Prepend short server rules to keep behavior consistent
        system_text = DEFAULT_SYSTEM + "\n" + system_text
    convo = "".join(turns)
    return f"{BOS}{system_text}\n{convo}{AST}"
def input_to_messages(inp: Any) -> List[ChatMessage]:
    """Normalize the /v1/responses ``input`` field into chat messages.

    Accepts None, a plain string, a Responses-API list of ``type:"message"``
    items, or an OpenAI-style list of role/content dicts; anything else is
    stringified into a single user message.
    """
    if inp is None:
        return [ChatMessage(role="user", content="")]
    if isinstance(inp, str):
        return [ChatMessage(role="user", content=inp)]
    if isinstance(inp, list) and inp and isinstance(inp[0], dict):
        first = inp[0]
        if first.get("type") == "message":
            out: List[ChatMessage] = []
            for item in inp:
                blocks = item.get("content", [])
                text_parts: List[str] = []
                if isinstance(blocks, list):
                    # Only plain text blocks are supported; others are dropped.
                    text_parts = [
                        b.get("text", "")
                        for b in blocks
                        if isinstance(b, dict) and b.get("type") == "text"
                    ]
                out.append(
                    ChatMessage(role=item.get("role", "user"), content="".join(text_parts))
                )
            return out
        if "role" in first:
            return [
                ChatMessage(role=m.get("role", "user"), content=m.get("content", ""))
                for m in inp
            ]
    return [ChatMessage(role="user", content=str(inp))]
# -----------------------
# Endpoints
# -----------------------
@app.get("/")
def root():
    """Service discovery: advertise the endpoints this server exposes."""
    endpoints = [
        "/v1/models",
        "/v1/chat/completions",
        "/v1/responses",
        "/v1/completions",
        "/generate_stream",
    ]
    return {"ok": True, "service": "openai-compatible", "endpoints": endpoints}
@app.get("/health")
def health():
    """Unauthenticated probe: load state, any startup error, and active tuning knobs."""
    report = {
        "ok": True,
        "model_loaded": llm is not None,
        "load_error": LOAD_ERROR,
    }
    report.update(
        model=PRIMARY_MODEL_ID,
        threads=N_THREADS,
        ctx=N_CTX,
        batch=N_BATCH,
        ctx_margin=CTX_MARGIN,
        keep_last_turns=KEEP_LAST_TURNS,
        sse_flush_chars=SSE_FLUSH_CHARS,
        sse_flush_sec=SSE_FLUSH_SEC,
    )
    return report
@app.get("/v1/models")
def v1_models(authorization: Optional[str] = Header(default=None)):
    """List available models; includes common aliases so clients that
    hardcode 'gpt-4' or 'auto' still route to the local model."""
    require_auth(authorization)
    model_ids = [PRIMARY_MODEL_ID, "gpt-4", "gpt-3.5-turbo", "auto"]
    return {
        "object": "list",
        "data": [{"id": mid, "object": "model", "owned_by": "me"} for mid in model_ids],
    }
# -----------------------
# /v1/chat/completions (OpenAI + FAST SSE)
# -----------------------
@app.post("/v1/chat/completions")
def v1_chat_completions(req: ChatCompletionsReq, authorization: Optional[str] = Header(default=None)):
    """OpenAI-compatible chat completions with buffered SSE streaming.

    Trims history, builds a DeepSeek prompt, clamps the token budget to the
    context window, then either streams ``chat.completion.chunk`` SSE events
    or returns a single ``chat.completion`` object. Unexpected failures are
    converted to OpenAI-style error JSON (HTTP 500).
    """
    try:
        require_auth(authorization)
        ensure_model_loaded()
        # Trim long histories for speed/stability
        msgs = trim_messages_for_speed(req.messages, KEEP_LAST_TURNS)
        prompt = messages_to_prompt(msgs)
        # max_completion_tokens (newer OpenAI field) wins over max_tokens.
        requested = (
            req.max_completion_tokens
            if req.max_completion_tokens is not None
            else (req.max_tokens if req.max_tokens is not None else MAX_TOKENS_DEFAULT)
        )
        requested = int(requested)
        max_toks = clamp_max_tokens(prompt, requested)
        temperature = float(req.temperature if req.temperature is not None else 0.2)
        top_p = float(req.top_p if req.top_p is not None else 0.9)
        if req.stream:
            stream_id = f"chatcmpl-{uuid.uuid4().hex}"
            created = int(time.time())
            def sse_gen():
                # Buffer tokens and emit batched SSE events: fewer, larger
                # chunks cost far less than one event per generated token.
                buf: List[str] = []
                last_flush = time.time()
                def flush():
                    # Drain the buffer into one chat.completion.chunk event;
                    # returns None when there is nothing to send.
                    nonlocal buf, last_flush
                    if not buf:
                        return None
                    text = "".join(buf)
                    buf = []
                    last_flush = time.time()
                    event = {
                        "id": stream_id,
                        "object": "chat.completion.chunk",
                        "created": created,
                        "model": PRIMARY_MODEL_ID,
                        "choices": [{"index": 0, "delta": {"content": text}, "finish_reason": None}],
                    }
                    return f"data: {json.dumps(event, ensure_ascii=False)}\n\n"
                for chunk in llm(
                    prompt,
                    max_tokens=max_toks,
                    temperature=temperature,
                    top_p=top_p,
                    stream=True,
                ):
                    token = chunk["choices"][0]["text"] or ""
                    if not token:
                        continue
                    # Strip thinking inline
                    # NOTE(review): only works when <think> tags arrive whole
                    # inside one token; a tag split across tokens slips through.
                    token = THINK_BLOCK_RE.sub("", token)
                    if "<think>" in token:
                        token = token.split("<think>", 1)[0]
                    if not token:
                        continue
                    buf.append(token)
                    # Flush less often to reduce SSE overhead (big speed win)
                    buf_len = sum(len(x) for x in buf)
                    if buf_len >= SSE_FLUSH_CHARS or (time.time() - last_flush) >= SSE_FLUSH_SEC:
                        out = flush()
                        if out:
                            yield out
                # Final flush
                out = flush()
                if out:
                    yield out
                # Terminal chunk with finish_reason, then the OpenAI sentinel.
                final = {
                    "id": stream_id,
                    "object": "chat.completion.chunk",
                    "created": created,
                    "model": PRIMARY_MODEL_ID,
                    "choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}],
                }
                yield f"data: {json.dumps(final, ensure_ascii=False)}\n\n"
                yield "data: [DONE]\n\n"
            return StreamingResponse(sse_gen(), media_type="text/event-stream")
        # Non-stream
        out = llm(prompt, max_tokens=max_toks, temperature=temperature, top_p=top_p)
        text = strip_think(out["choices"][0]["text"])
        created = int(time.time())
        return {
            "id": f"chatcmpl-{uuid.uuid4().hex}",
            "object": "chat.completion",
            "created": created,
            "model": PRIMARY_MODEL_ID,
            "choices": [{"index": 0, "message": {"role": "assistant", "content": text}, "finish_reason": "stop"}],
        }
    except HTTPException:
        # Auth failures keep their proper 401/403 status.
        raise
    except Exception as e:
        return oai_error(str(e), code="internal_error", status=500)
# -----------------------
# /v1/responses (n8n/LangChain; minimal)
# -----------------------
@app.post("/v1/responses")
def v1_responses(req: ResponsesReq, authorization: Optional[str] = Header(default=None)):
    """Minimal OpenAI Responses endpoint (non-streaming) for n8n/LangChain."""
    try:
        require_auth(authorization)
        ensure_model_loaded()
        msgs = trim_messages_for_speed(input_to_messages(req.input), KEEP_LAST_TURNS)
        prompt = messages_to_prompt(msgs)
        budget = MAX_TOKENS_DEFAULT if req.max_output_tokens is None else int(req.max_output_tokens)
        max_toks = clamp_max_tokens(prompt, budget)
        temperature = 0.2 if req.temperature is None else float(req.temperature)
        top_p = 0.9 if req.top_p is None else float(req.top_p)
        result = llm(prompt, max_tokens=max_toks, temperature=temperature, top_p=top_p)
        text = strip_think(result["choices"][0]["text"])
        return {
            "id": f"resp_{uuid.uuid4().hex}",
            "object": "response",
            "created": int(time.time()),
            "model": PRIMARY_MODEL_ID,
            "output_text": text,
            "output": [
                {
                    "type": "message",
                    "role": "assistant",
                    "content": [{"type": "output_text", "text": text}],
                }
            ],
        }
    except HTTPException:
        raise
    except Exception as e:
        return oai_error(str(e), code="internal_error", status=500)
# -----------------------
# /v1/completions (legacy)
# -----------------------
@app.post("/v1/completions")
def v1_completions(req: CompletionsReq, authorization: Optional[str] = Header(default=None)):
    """Legacy OpenAI text-completions endpoint with optional SSE streaming.

    Wraps the raw prompt in the DeepSeek chat template with DEFAULT_SYSTEM,
    clamps the token budget, and emits ``text_completion`` payloads.
    """
    try:
        require_auth(authorization)
        ensure_model_loaded()
        prompt_in = req.prompt
        # OpenAI allows a list of prompts; this server concatenates them.
        if isinstance(prompt_in, list):
            prompt_in = "\n".join(str(x) for x in prompt_in)
        if prompt_in is None:
            prompt_in = ""
        prompt = f"{BOS}{DEFAULT_SYSTEM}\n{USR}{prompt_in}\n{AST}"
        requested = int(req.max_tokens if req.max_tokens is not None else MAX_TOKENS_DEFAULT)
        max_toks = clamp_max_tokens(prompt, requested)
        temperature = float(req.temperature if req.temperature is not None else 0.2)
        top_p = float(req.top_p if req.top_p is not None else 0.9)
        if req.stream:
            comp_id = f"cmpl-{uuid.uuid4().hex}"
            created = int(time.time())
            def sse_gen():
                # Same batched-flush strategy as /v1/chat/completions.
                buf: List[str] = []
                last_flush = time.time()
                def flush():
                    # Drain the buffer into one text_completion SSE event.
                    nonlocal buf, last_flush
                    if not buf:
                        return None
                    text = "".join(buf)
                    buf = []
                    last_flush = time.time()
                    event = {
                        "id": comp_id,
                        "object": "text_completion",
                        "created": created,
                        "model": PRIMARY_MODEL_ID,
                        "choices": [{"index": 0, "text": text, "finish_reason": None}],
                    }
                    return f"data: {json.dumps(event, ensure_ascii=False)}\n\n"
                for chunk in llm(prompt, max_tokens=max_toks, temperature=temperature, top_p=top_p, stream=True):
                    token = chunk["choices"][0]["text"] or ""
                    if not token:
                        continue
                    # Strip reasoning tokens inline (same caveat as chat: a
                    # <think> tag split across tokens is not caught).
                    token = THINK_BLOCK_RE.sub("", token)
                    if "<think>" in token:
                        token = token.split("<think>", 1)[0]
                    if not token:
                        continue
                    buf.append(token)
                    buf_len = sum(len(x) for x in buf)
                    if buf_len >= SSE_FLUSH_CHARS or (time.time() - last_flush) >= SSE_FLUSH_SEC:
                        out = flush()
                        if out:
                            yield out
                out = flush()
                if out:
                    yield out
                # NOTE(review): no terminal finish_reason chunk is sent here,
                # unlike the chat endpoint — confirm legacy clients accept that.
                yield "data: [DONE]\n\n"
            return StreamingResponse(sse_gen(), media_type="text/event-stream")
        out = llm(prompt, max_tokens=max_toks, temperature=temperature, top_p=top_p)
        text = strip_think(out["choices"][0]["text"])
        created = int(time.time())
        return {
            "id": f"cmpl-{uuid.uuid4().hex}",
            "object": "text_completion",
            "created": created,
            "model": PRIMARY_MODEL_ID,
            "choices": [{"index": 0, "text": text, "finish_reason": "stop"}],
        }
    except HTTPException:
        raise
    except Exception as e:
        return oai_error(str(e), code="internal_error", status=500)
# -----------------------
# /generate_stream (plain text streaming; fastest)
# -----------------------
@app.post("/generate_stream")
def generate_stream(req: GenerateStreamReq, authorization: Optional[str] = Header(default=None)):
    """Stream raw completion text with no SSE framing — the lowest-latency path."""
    try:
        require_auth(authorization)
        ensure_model_loaded()
        prompt = f"{BOS}{DEFAULT_SYSTEM}\n{USR}{req.prompt}\n{AST}"
        budget = 200 if req.max_new_tokens is None else int(req.max_new_tokens)
        max_toks = clamp_max_tokens(prompt, budget)
        temperature = 0.2 if req.temperature is None else float(req.temperature)
        top_p = 0.9 if req.top_p is None else float(req.top_p)

        def token_gen():
            stream = llm(prompt, max_tokens=max_toks, temperature=temperature,
                         top_p=top_p, stream=True)
            for chunk in stream:
                piece = chunk["choices"][0]["text"] or ""
                if not piece:
                    continue
                # Drop reasoning: whole <think> blocks first, then anything
                # trailing a dangling open tag.
                piece = THINK_BLOCK_RE.sub("", piece)
                if "<think>" in piece:
                    piece = piece.split("<think>", 1)[0]
                if piece:
                    yield piece

        return StreamingResponse(token_gen(), media_type="text/plain")
    except HTTPException:
        raise
    except Exception as e:
        return oai_error(str(e), code="internal_error", status=500)
# Local entry point: serve on the port HF Spaces provides (default 7860).
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=int(os.environ.get("PORT", "7860")))