deepseek / app.py
han145's picture
Update app.py
617d87f verified
import os
import time
import logging
from fastapi import FastAPI, Request, HTTPException, Depends, status
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from fastapi.responses import JSONResponse
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import gc
# Logging setup: timestamped records at INFO level and above for the process.
_LOG_FORMAT = '%(asctime)s - %(levelname)s - %(message)s'
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)
logger = logging.getLogger(__name__)
# 全局变量
model = None
tokenizer = None
# 配置
MODEL_NAME = "Qwen/Qwen1.5-0.5B-Chat"
MAX_TOKENS = 512
DEVICE = "cpu" # 强制使用 CPU
# API 密钥配置
API_KEYS = os.getenv("API_KEYS", "your-secret-key-1,your-secret-key-2").split(",")
API_AUTH_ENABLED = os.getenv("API_AUTH_ENABLED", "true").lower() == "true"
# Bearer authentication. auto_error=False makes HTTPBearer return None (instead
# of raising 403) when the Authorization header is missing or not a Bearer
# credential. That lets verify_api_key honour API_AUTH_ENABLED=false and reply
# with a proper 401 when auth is on.
security = HTTPBearer(auto_error=False)


def verify_api_key(credentials: HTTPAuthorizationCredentials = Depends(security)):
    """Validate the request's Bearer API key against ``API_KEYS``.

    Args:
        credentials: Parsed ``Authorization`` header provided by ``security``;
            ``None`` when the header is absent or is not a Bearer credential.

    Returns:
        ``True`` when authentication is disabled or the key is valid.

    Raises:
        HTTPException: 401 (with ``WWW-Authenticate: Bearer``) when auth is
            enabled and the credential is missing, uses a non-Bearer scheme,
            or does not match any configured key.
    """
    import secrets

    if not API_AUTH_ENABLED:
        return True

    challenge = {"WWW-Authenticate": "Bearer"}
    if credentials is None:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Not authenticated",
            headers=challenge,
        )
    # Scheme names are case-insensitive (HTTPBearer itself accepts "bearer").
    if credentials.scheme.lower() != "bearer":
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid authentication scheme. Use 'Bearer' token",
            headers=challenge,
        )

    supplied = credentials.credentials.encode("utf-8")
    # Compare against every key with compare_digest, without short-circuiting,
    # so response timing does not reveal how much of a key matched.
    matched = False
    for key in API_KEYS:
        matched |= secrets.compare_digest(supplied, key.encode("utf-8"))
    if not matched:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid API key",
            headers=challenge,
        )
    return True
def load_model():
    """Load the tokenizer and causal-LM weights for MODEL_NAME into the module globals.

    Returns:
        True on success. On any failure the exception is logged with its
        traceback, ``model`` and ``tokenizer`` are reset to None, and False
        is returned (no exception propagates).
    """
    global model, tokenizer
    try:
        logger.info(f"开始加载模型: {MODEL_NAME}")
        tokenizer = AutoTokenizer.from_pretrained(
            MODEL_NAME,
            trust_remote_code=True,
        )
        if tokenizer.pad_token is None:
            # generate() needs a pad id; reuse EOS when the tokenizer has none.
            tokenizer.pad_token = tokenizer.eos_token

        # Half precision only pays off on GPU. On CPU, fp16 matmuls are very
        # slow and, on older torch builds, not implemented at all, so use fp32.
        dtype = torch.float16 if str(DEVICE).startswith("cuda") else torch.float32
        model = AutoModelForCausalLM.from_pretrained(
            MODEL_NAME,
            torch_dtype=dtype,
            device_map=None,
            low_cpu_mem_usage=True,
            trust_remote_code=True,
        )
        model = model.to(DEVICE)
        model.eval()
        logger.info("模型加载成功")
        return True
    except Exception as e:
        logger.error(f"模型加载失败: {e}", exc_info=True)
        # Do not leave a half-initialised pair behind (e.g. tokenizer loaded
        # but model failed); release whatever was allocated.
        model = None
        tokenizer = None
        gc.collect()
        return False
def _message_content_to_text(content):
    """Flatten one message's ``content`` into a stripped string ('' if nothing usable)."""
    if content is None:
        # OpenAI allows content=None (e.g. assistant tool-call turns).
        return ""
    if isinstance(content, list):
        # Multimodal-style content: keep text parts and bare strings only.
        pieces = []
        for item in content:
            if isinstance(item, dict):
                if item.get("type") == "text" and item.get("text") is not None:
                    pieces.append(str(item["text"]))
            elif isinstance(item, str):
                pieces.append(item)
        return " ".join(p for p in pieces if p).strip()
    return str(content).strip()


def apply_chat_template(messages):
    """Render OpenAI-style chat messages into Qwen's ChatML prompt.

    Each message with role ``system``, ``user`` or ``assistant`` (compared
    case-insensitively) and non-empty text becomes
    ``<|im_start|>{role}\\n{text}<|im_end|>\\n``. Other roles and empty
    messages are skipped. The prompt always ends with an open
    ``<|im_start|>assistant\\n`` turn for the model to complete.

    Args:
        messages: Iterable of dicts with optional ``role`` and ``content`` keys.
            ``content`` may be a string, ``None``, or a list of text parts.

    Returns:
        The prompt string.
    """
    # NOTE(review): message text is inserted verbatim, so content containing
    # ChatML markers such as "<|im_start|>" is likely tokenized as real control
    # tokens. Consider escaping if callers are untrusted.
    supported_roles = ("system", "user", "assistant")
    turns = []
    for msg in messages:
        role = str(msg.get("role") or "").lower()
        text = _message_content_to_text(msg.get("content"))
        if not text or role not in supported_roles:
            continue
        turns.append(f"<|im_start|>{role}\n{text}<|im_end|>\n")
    turns.append("<|im_start|>assistant\n")
    return "".join(turns)
def generate_chat_response(messages, max_tokens=512, temperature=0.7, max_input_tokens=2048):
    """Generate one assistant reply for ``messages`` with the loaded model.

    Args:
        messages: OpenAI-style message list, rendered with ``apply_chat_template``.
        max_tokens: Requested number of new tokens. Values that are not
            integer-convertible (including ``None``) fall back to ``MAX_TOKENS``.
            The result is clamped to ``[1, MAX_TOKENS]``.
        temperature: Sampling temperature. Values that are not float-convertible
            become 0.7. Values ``<= 0`` select greedy decoding, because
            transformers rejects sampling at temperature 0.
        max_input_tokens: Prompt budget in tokens. Longer prompts keep only
            their LAST ``max_input_tokens`` tokens so the trailing
            ``<|im_start|>assistant`` turn survives.

    Returns:
        On success, ``{"text": str, "finish_reason": "stop" | "length",
        "usage": {"prompt_tokens", "completion_tokens", "total_tokens"}}``.
        On failure, ``{"error": str}``; exceptions are logged, never raised.
    """
    if model is None or tokenizer is None:
        return {"error": "模型未加载"}
    try:
        prompt = apply_chat_template(messages)
        logger.info(f"输入文本类型: {type(prompt)}, 长度: {len(prompt)}")

        encoded = tokenizer([prompt], return_tensors="pt")
        # Tokenizer truncation cuts from the right by default, which would drop
        # the generation prompt at the end of the text. Keep the tail instead.
        inputs = {k: v[:, -max_input_tokens:].to(DEVICE) for k, v in encoded.items()}
        prompt_len = int(inputs["input_ids"].shape[1])

        try:
            max_new_tokens = int(max_tokens)
        except (TypeError, ValueError):
            max_new_tokens = MAX_TOKENS
        max_new_tokens = max(1, min(max_new_tokens, MAX_TOKENS))

        try:
            temperature = float(temperature)
        except (TypeError, ValueError):
            temperature = 0.7

        gen_kwargs = {
            "max_new_tokens": max_new_tokens,
            "repetition_penalty": 1.05,
            "pad_token_id": tokenizer.eos_token_id,
            "eos_token_id": tokenizer.eos_token_id,
        }
        if temperature > 0:
            gen_kwargs.update(do_sample=True, temperature=temperature, top_p=0.85)
        else:
            gen_kwargs["do_sample"] = False

        with torch.no_grad():
            outputs = model.generate(**inputs, **gen_kwargs)

        # For a decoder-only model, generate() returns prompt + continuation;
        # decode only the newly generated tokens.
        new_tokens = outputs[0][prompt_len:]
        text = tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
        completion_tokens = int(new_tokens.shape[0])
        ended_on_eos = completion_tokens > 0 and int(new_tokens[-1]) == tokenizer.eos_token_id
        if completion_tokens >= max_new_tokens and not ended_on_eos:
            finish_reason = "length"
        else:
            finish_reason = "stop"

        return {
            "text": text,
            "finish_reason": finish_reason,
            "usage": {
                "prompt_tokens": prompt_len,
                "completion_tokens": completion_tokens,
                "total_tokens": prompt_len + completion_tokens,
            },
        }
    except Exception as e:
        logger.error(f"生成失败: {str(e)}", exc_info=True)
        return {"error": str(e)}
# The FastAPI application object; every route below is registered on it.
app = FastAPI(
    title="Qwen OpenAI-compatible API",
    description="仅提供 /v1/chat/completions 端点",
    version="1.0",
)
@app.on_event("startup")
async def startup_event():
    """Load the model once at server start-up and log whether it succeeded.

    A failed load is only logged; the server keeps running.
    """
    loaded = load_model()
    if not loaded:
        logger.error("模型加载失败,服务可能无法正常工作")
        return
    logger.info("服务启动完成")
# Health check
@app.get("/health")
async def health_check():
    """Report whether the model global is populated, plus the current Unix time."""
    loaded = model is not None
    status_text = "healthy" if loaded else "model loading failed"
    return {
        "status": status_text,
        "model_loaded": loaded,
        "timestamp": int(time.time()),
    }
# Root path
@app.get("/")
async def root():
    """Return a short message pointing clients at the chat-completions route."""
    banner = "Qwen API 服务运行中,仅支持 /v1/chat/completions"
    return {"message": banner}
# Core endpoint
@app.post("/v1/chat/completions")
async def create_chat_completion(
    request: Request,
    auth_valid: bool = Depends(verify_api_key),
):
    """OpenAI-compatible, non-streaming chat completion.

    Reads ``messages``, ``max_tokens`` (or ``max_completion_tokens``) and
    ``temperature`` from the JSON body and returns a ``chat.completion``
    object with one choice. Errors use OpenAI's
    ``{"error": {"message", "type"}}`` shape:

    - 400 for a malformed body or invalid ``messages``;
    - 503 while the model is not loaded;
    - 500 for generation failures.
    """
    # NOTE(review): "stream": true is not supported; the reply is always a
    # single JSON object.
    import uuid
    from fastapi.concurrency import run_in_threadpool

    def error_response(status_code, message, error_type):
        return JSONResponse(
            status_code=status_code,
            content={"error": {"message": message, "type": error_type}},
        )

    try:
        try:
            data = await request.json()
        except ValueError:
            # json.JSONDecodeError and UnicodeDecodeError both subclass ValueError.
            return error_response(400, "请求体必须是合法的 JSON", "invalid_request_error")
        if not isinstance(data, dict):
            return error_response(400, "请求体必须是 JSON 对象", "invalid_request_error")

        messages = data.get("messages", [])
        max_tokens = data.get("max_tokens")
        if max_tokens is None:
            # Newer OpenAI clients send max_completion_tokens instead.
            max_tokens = data.get("max_completion_tokens")
        if max_tokens is None:
            max_tokens = MAX_TOKENS
        temperature = data.get("temperature")
        if temperature is None:
            temperature = 0.7

        # Validate the type before calling len(), so bad input yields 400, not 500.
        if not messages or not isinstance(messages, list):
            return error_response(400, "messages 必须是非空列表", "invalid_request_error")
        if not all(isinstance(m, dict) for m in messages):
            return error_response(400, "messages 中的每一项都必须是对象", "invalid_request_error")
        logger.info(f"收到请求: messages_count={len(messages)}")

        if model is None or tokenizer is None:
            return error_response(503, "模型未加载", "service_unavailable")

        # Generation is synchronous and CPU-bound; running it directly in this
        # async handler would stall the event loop (and every other request,
        # including /health) until it finished.
        result = await run_in_threadpool(
            generate_chat_response, messages, max_tokens, temperature
        )
        if "error" in result:
            raise RuntimeError(result["error"])

        response_data = {
            "id": f"chatcmpl-{uuid.uuid4().hex}",
            "object": "chat.completion",
            "created": int(time.time()),
            "model": MODEL_NAME,
            "choices": [
                {
                    "index": 0,
                    "message": {
                        "role": "assistant",
                        "content": result["text"],
                    },
                    "finish_reason": result.get("finish_reason", "stop"),
                }
            ],
        }
        if "usage" in result:
            response_data["usage"] = result["usage"]
        return response_data
    except Exception as e:
        logger.error(f"Chat Completions 错误: {str(e)}", exc_info=True)
        return error_response(500, str(e), "internal_server_error")
if __name__ == "__main__":
    import uvicorn

    # HOST/PORT env vars allow running on other platforms; the defaults match
    # the previously hard-coded 0.0.0.0:7860.
    uvicorn.run(
        app,
        host=os.getenv("HOST", "0.0.0.0"),
        port=int(os.getenv("PORT", "7860")),
        workers=1,
        log_level="info",
    )