from fastapi import FastAPI, HTTPException
from fastapi.responses import JSONResponse
from pydantic import BaseModel, Field
from typing import List, Optional
import os
import warnings
# 屏蔽 Pydantic 弃用警告(可选,保持日志清洁)
warnings.filterwarnings("ignore", category=DeprecationWarning, module="pydantic")
from llama_cpp import Llama
# Load the GGUF-quantized model at import time so the server fails fast
# when the model file is missing (instead of on the first request).
MODEL_PATH = "/app/models/Qwen2.5-3B-Instruct-Q4_K_M.gguf"
if not os.path.exists(MODEL_PATH):
    raise FileNotFoundError(f"Model not found at {MODEL_PATH}")
llm = Llama(
    model_path=MODEL_PATH,
    n_ctx=32768,  # 32K context window
    n_threads=4,
    chat_format="chatml",
    # NOTE(review): presumably fetches the matching HF tokenizer so ChatML
    # templating is exact — confirm against llama-cpp-python docs.
    hf_pretrained_model_name_or_path="Qwen/Qwen2.5-3B-Instruct",
    verbose=False,
)
app = FastAPI(title="Qwen2.5-3B API (32K)", version="0.2.0")
class Message(BaseModel):
    # One chat turn in OpenAI Chat Completions format.
    # No class docstring on purpose: Pydantic would surface it as the model
    # description in the generated OpenAPI schema.
    role: str = Field(..., description="Role: 'system', 'user', or 'assistant'")
    content: str = Field(..., description="Message content")
class ChatRequest(BaseModel):
    # OpenAI-compatible request body for POST /v1/chat/completions.
    # `model` is accepted for client compatibility but ignored (single model).
    model: str = Field(..., description="Model identifier (ignored, single model)")
    messages: List[Message] = Field(..., description="List of messages")
    max_tokens: Optional[int] = Field(None, description="Maximum tokens to generate")
    # Accepted for compatibility; the endpoint does not stream.
    stream: Optional[bool] = Field(False, description="Stream response (not supported)")
@app.post("/v1/chat/completions")
async def chat_completion(req: ChatRequest):
    """
    OpenAI-compatible Chat Completions endpoint.

    Note: even with the context window set to 32K, this 3B model's
    generation quality may degrade on very long contexts.

    Raises:
        HTTPException(400): if ``stream=true`` is requested, or if the
            underlying llama.cpp call fails (e.g. context overflow).
    """
    # Streaming is advertised as unsupported, but req.stream used to be
    # forwarded to create_chat_completion; with stream=true that returns a
    # generator which JSONResponse cannot serialize. Reject it explicitly.
    if req.stream:
        raise HTTPException(status_code=400, detail="stream=true is not supported")
    try:
        # model_dump() replaces the deprecated dict() (Pydantic v2).
        messages_list = [m.model_dump() for m in req.messages]
        result = llm.create_chat_completion(
            messages=messages_list,
            max_tokens=req.max_tokens,
            stream=False,  # non-streaming only; validated above
        )
        return JSONResponse(content=result)
    except Exception as e:
        # Surface inference errors as client-visible 400s, preserving cause.
        raise HTTPException(status_code=400, detail=str(e)) from e
@app.get("/healthz")
async def healthz():
    """Liveness probe: report service status and the loaded model's context size."""
    payload = {
        "status": "ok",
        "model": "Qwen2.5-3B-Instruct",
        "n_ctx": 32768,
    }
    return payload
# For HF Spaces compatibility: run a local uvicorn server when executed
# directly. (Also removes a stray trailing "|" artifact after uvicorn.run
# that would be a SyntaxError.)
if __name__ == "__main__":
    import uvicorn

    # HF Spaces injects PORT; fall back to the conventional 7860.
    port = int(os.getenv("PORT", "7860"))
    uvicorn.run(app, host="0.0.0.0", port=port)