File size: 7,173 Bytes
354a3ef 81cd270 a7af7f9 354a3ef 847d3f0 81cd270 c1cd0a0 847d3f0 81cd270 24ea806 847d3f0 8845fce 581ae34 354a3ef c1cd0a0 581ae34 c1cd0a0 a7af7f9 c1cd0a0 a7af7f9 c1cd0a0 a7af7f9 c1cd0a0 a7af7f9 354a3ef c1cd0a0 354a3ef c356580 354a3ef c356580 c1cd0a0 c356580 354a3ef c1cd0a0 354a3ef c356580 354a3ef c1cd0a0 354a3ef 9417203 354a3ef c356580 847d3f0 c1cd0a0 847d3f0 81be34a 847d3f0 dd9f413 c1cd0a0 dd9f413 81be34a dd9f413 c1cd0a0 dd9f413 c1cd0a0 dd9f413 847d3f0 c1cd0a0 847d3f0 dd9f413 847d3f0 dd9f413 847d3f0 81be34a c1cd0a0 847d3f0 c1cd0a0 581ae34 354a3ef 847d3f0 725c631 847d3f0 c1cd0a0 847d3f0 1286390 847d3f0 725c631 847d3f0 1286390 725c631 e00d22a 847d3f0 c1cd0a0 847d3f0 24ea806 69fd688 76880ea 847d3f0 581ae34 847d3f0 640b492 c1cd0a0 a7af7f9 c1cd0a0 a7af7f9 c1cd0a0 a7af7f9 26be9f6 354a3ef c1cd0a0 581ae34 1286390 e00d22a c1cd0a0 e00d22a 1286390 847d3f0 c1cd0a0 e00d22a 1286390 174f98d 847d3f0 c1cd0a0 847d3f0 5643b00 847d3f0 69fd688 847d3f0 174f98d 847d3f0 85e708c 57f70e9 174f98d 57f70e9 174f98d 57f70e9 174f98d 57f70e9 174f98d 847d3f0 85e708c 847d3f0 174f98d 847d3f0 174f98d 847d3f0 8845fce 640b492 354a3ef c1cd0a0 581ae34 c1cd0a0 354a3ef | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 | import os
import time
import logging
from fastapi import FastAPI, Request, HTTPException, Depends, status
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from fastapi.responses import JSONResponse
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import gc
# Logging configuration
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Global state, populated by load_model() at startup
model = None
tokenizer = None

# Model / generation configuration
MODEL_NAME = "Qwen/Qwen1.5-0.5B-Chat"
MAX_TOKENS = 512  # default max_tokens for chat completion requests
DEVICE = "cpu"  # always run on CPU

# API key configuration.
# API_KEYS is a comma-separated list taken from the environment. Whitespace
# around entries is stripped and blank entries are dropped ("a, b," -> ["a", "b"]).
# There is deliberately no built-in default: shipping well-known placeholder
# keys would let anyone who has read this source authenticate. With auth
# enabled and no keys configured, every authenticated request is rejected
# (fail closed).
API_KEYS = [key.strip() for key in os.getenv("API_KEYS", "").split(",") if key.strip()]
API_AUTH_ENABLED = os.getenv("API_AUTH_ENABLED", "true").strip().lower() == "true"
if API_AUTH_ENABLED and not API_KEYS:
    logger.warning(
        "API_AUTH_ENABLED is true but API_KEYS is empty; "
        "all requests to protected endpoints will be rejected"
    )

# Bearer authentication scheme (FastAPI security dependency)
security = HTTPBearer()
# Bearer-token extractor that returns None instead of raising when the
# Authorization header is missing or is not a bearer token. A plain
# ``HTTPBearer()`` rejects such requests itself before verify_api_key runs,
# which made API_AUTH_ENABLED=false ineffective for header-less requests.
_optional_bearer = HTTPBearer(auto_error=False)


def verify_api_key(credentials: HTTPAuthorizationCredentials = Depends(_optional_bearer)):
    """Authorize a request by its ``Authorization: Bearer <key>`` header.

    Args:
        credentials: Parsed bearer credentials, or None when the header is
            missing or does not use the bearer scheme.

    Returns:
        True when the request may proceed. With API_AUTH_ENABLED false every
        request is allowed, with or without an Authorization header.

    Raises:
        HTTPException: 401 (with ``WWW-Authenticate: Bearer``) when auth is
            enabled and the header is missing, uses another scheme, or
            carries a key that is not listed in API_KEYS.
    """
    import hmac

    if not API_AUTH_ENABLED:
        return True

    challenge = {"WWW-Authenticate": "Bearer"}
    if credentials is None:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Not authenticated",
            headers=challenge,
        )
    # Authentication schemes are case-insensitive (RFC 7235): accept "bearer" too.
    if credentials.scheme.lower() != "bearer":
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid authentication scheme. Use 'Bearer' token",
            headers=challenge,
        )

    presented = credentials.credentials.encode("utf-8")
    # Ignore surrounding whitespace / blank entries in the configured list, and
    # compare in constant time so response timing does not reveal key prefixes.
    configured = [key.strip().encode("utf-8") for key in API_KEYS if key.strip()]
    if not any(hmac.compare_digest(presented, key) for key in configured):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid API key",
            headers=challenge,
        )
    return True
def load_model():
    """Load the tokenizer and model named by MODEL_NAME into the module globals.

    Sets the global ``tokenizer`` and ``model`` (moved to DEVICE, eval mode).

    Returns:
        True on success; False if anything failed (the error is logged with
        its traceback).
    """
    global model, tokenizer
    try:
        logger.info(f"开始加载模型: {MODEL_NAME}")
        tokenizer = AutoTokenizer.from_pretrained(
            MODEL_NAME,
            trust_remote_code=True,
        )
        # Fall back to EOS as the pad token if the tokenizer defines none.
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

        # Half precision matmuls are not implemented for CPU in older PyTorch
        # releases ("addmm_impl_cpu_" not implemented for 'Half') and are slow
        # in newer ones, so load float32 on CPU and float16 only elsewhere.
        dtype = torch.float32 if str(DEVICE).startswith("cpu") else torch.float16
        model = AutoModelForCausalLM.from_pretrained(
            MODEL_NAME,
            torch_dtype=dtype,
            device_map=None,
            low_cpu_mem_usage=True,
            trust_remote_code=True,
        )
        model = model.to(DEVICE)
        model.eval()  # inference only
        logger.info("模型加载成功")
        return True
    except Exception as e:
        # logger.exception also records the traceback, which plain error() dropped.
        logger.exception(f"模型加载失败: {e}")
        return False
def apply_chat_template(messages):
    """Render OpenAI-style chat messages into Qwen's ChatML prompt format.

    Each message whose role is "system", "user" or "assistant" (case-insensitive)
    and whose text content is non-empty becomes
    ``<|im_start|>{role}\\n{content}<|im_end|>\\n``. The prompt always ends with an
    open ``<|im_start|>assistant\\n`` turn for the model to complete.

    ``content`` may be a string or an OpenAI multimodal list; for lists only
    plain strings and ``{"type": "text"}`` parts are kept (joined with single
    spaces) and other part types are ignored. Messages that are not dicts, have
    another role, or have empty/null content are skipped.

    Args:
        messages: Iterable of message dicts (``{"role": ..., "content": ...}``).

    Returns:
        The prompt string.
    """
    allowed_roles = ("system", "user", "assistant")
    turns = []
    for msg in messages:
        if not isinstance(msg, dict):
            continue
        # "or" guards against explicit nulls such as {"role": null}.
        role = str(msg.get("role") or "").lower()
        content_str = _message_text(msg.get("content"))
        if not content_str or role not in allowed_roles:
            continue
        turns.append(f"<|im_start|>{role}\n{content_str}<|im_end|>\n")
    turns.append("<|im_start|>assistant\n")
    return "".join(turns)


def _message_text(content):
    """Return the stripped text of a message ``content`` value ("" for None).

    A null content (e.g. an assistant tool-call turn) must not be rendered as
    the literal string "None", which ``str(None)`` would produce.
    """
    if content is None:
        return ""
    if isinstance(content, list):
        parts = []
        for item in content:
            if isinstance(item, dict):
                if item.get("type") == "text":
                    text = item.get("text")
                    parts.append("" if text is None else str(text))
            elif isinstance(item, str):
                parts.append(item)
        return " ".join(p for p in parts if p).strip()
    return str(content).strip()
def generate_chat_response(messages, max_tokens=512, temperature=0.7):
    """Generate an assistant reply for OpenAI-style chat ``messages``.

    Args:
        messages: Chat messages, rendered with apply_chat_template().
        max_tokens: Maximum number of new tokens to generate, clamped to
            [1, MAX_TOKENS]; None or a non-numeric value means MAX_TOKENS.
        temperature: Sampling temperature (None means 0.7). Values <= 0 select
            greedy decoding, because transformers rejects non-positive
            temperatures when sampling.

    Returns:
        On success a dict with ``text``, ``finish_reason`` ("length" when the
        token budget ran out before EOS, otherwise "stop"), ``prompt_tokens``
        and ``completion_tokens``. On failure ``{"error": message}``.
    """
    if model is None or tokenizer is None:
        return {"error": "模型未加载"}
    try:
        prompt = apply_chat_template(messages)
        logger.info(f"输入文本类型: {type(prompt)}, 长度: {len(prompt)}")

        encoded = tokenizer([prompt], return_tensors="pt")
        # Left-truncate to the most recent 2048 tokens. The tokenizer's own
        # truncation cuts from the right by default, which would drop the
        # trailing "<|im_start|>assistant\n" turn the model has to continue.
        inputs = {k: v[:, -2048:].to(DEVICE) for k, v in encoded.items()}
        prompt_len = inputs["input_ids"].shape[1]

        max_new_tokens = _resolve_max_new_tokens(max_tokens)
        temperature = 0.7 if temperature is None else float(temperature)
        do_sample = temperature > 0
        gen_kwargs = {
            "max_new_tokens": max_new_tokens,
            "do_sample": do_sample,
            "repetition_penalty": 1.05,
            "pad_token_id": tokenizer.eos_token_id,
            "eos_token_id": tokenizer.eos_token_id,
        }
        if do_sample:
            gen_kwargs["temperature"] = temperature
            gen_kwargs["top_p"] = 0.85

        with torch.no_grad():
            outputs = model.generate(**inputs, **gen_kwargs)

        # generate() returns prompt + continuation; decode only the new tokens.
        new_tokens = outputs[0][prompt_len:]
        response = tokenizer.decode(new_tokens, skip_special_tokens=True).strip()

        completion_tokens = int(new_tokens.shape[0])
        stopped_on_eos = (
            completion_tokens > 0
            and int(new_tokens[-1]) == tokenizer.eos_token_id
        )
        ran_out = completion_tokens >= max_new_tokens and not stopped_on_eos
        return {
            "text": response,
            "finish_reason": "length" if ran_out else "stop",
            "prompt_tokens": int(prompt_len),
            "completion_tokens": completion_tokens,
        }
    except Exception as e:
        logger.error(f"生成失败: {str(e)}", exc_info=True)
        return {"error": str(e)}


def _resolve_max_new_tokens(max_tokens):
    """Turn a requested max_tokens into a generation budget in [1, MAX_TOKENS]."""
    try:
        requested = int(max_tokens)
    except (TypeError, ValueError):
        return MAX_TOKENS
    return max(1, min(requested, MAX_TOKENS))
# FastAPI application object; the route handlers in this module register on it.
_APP_METADATA = {
    "title": "Qwen OpenAI-compatible API",
    "version": "1.0",
    "description": "仅提供 /v1/chat/completions 端点",
}
app = FastAPI(**_APP_METADATA)
@app.on_event("startup")
async def startup_event():
    """Load the model when the server starts and log the outcome.

    A failed load is only logged; the server keeps running and /health then
    reports ``model_loaded`` as False.
    """
    loaded = load_model()
    if not loaded:
        logger.error("模型加载失败,服务可能无法正常工作")
        return
    logger.info("服务启动完成")
# Health check
@app.get("/health")
async def health_check():
    """Report whether the model is loaded, plus the current Unix timestamp.

    Always answers HTTP 200; callers must inspect ``model_loaded`` / ``status``.
    """
    loaded = model is not None
    status_text = "healthy" if loaded else "model loading failed"
    return {
        "status": status_text,
        "model_loaded": loaded,
        "timestamp": int(time.time()),
    }
# Root path: a short banner pointing clients at the only supported endpoint.
@app.get("/")
async def root():
    """Return a fixed service-status message."""
    banner = "Qwen API 服务运行中,仅支持 /v1/chat/completions"
    return {"message": banner}
# Core endpoint
@app.post("/v1/chat/completions")
async def create_chat_completion(
    request: Request,
    auth_valid: bool = Depends(verify_api_key)
):
    """Handle an OpenAI-compatible, non-streaming chat completion request.

    Body fields: ``messages`` (required, non-empty list), ``max_tokens``
    (positive integer, default MAX_TOKENS) and ``temperature`` (non-negative
    number, default 0.7). ``auth_valid`` is not read; the dependency exists so
    verify_api_key runs (and may reject the request) before this handler.

    Errors use the OpenAI shape ``{"error": {"message", "type"}}``: 400 for a
    malformed request, 503 while the model is not loaded, and 500 when
    generation fails.
    """
    import uuid

    from fastapi.concurrency import run_in_threadpool

    def error_response(status_code, message, error_type):
        return JSONResponse(
            status_code=status_code,
            content={"error": {"message": message, "type": error_type}},
        )

    def bad_request(message):
        return error_response(400, message, "invalid_request_error")

    # Validate the body first so client mistakes are reported as 400, not 500.
    try:
        data = await request.json()
    except ValueError:  # json.JSONDecodeError / UnicodeDecodeError
        return bad_request("Request body must be valid JSON")
    if not isinstance(data, dict):
        return bad_request("Request body must be a JSON object")

    messages = data.get("messages", [])
    if not messages or not isinstance(messages, list):
        return bad_request("messages 必须是非空列表")
    logger.info(f"收到请求: messages_count={len(messages)}")

    max_tokens = data.get("max_tokens")
    if max_tokens is None:
        max_tokens = MAX_TOKENS
    elif isinstance(max_tokens, bool) or not isinstance(max_tokens, int) or max_tokens < 1:
        return bad_request("max_tokens must be a positive integer")

    temperature = data.get("temperature")
    if temperature is None:
        temperature = 0.7
    elif isinstance(temperature, bool) or not isinstance(temperature, (int, float)) or temperature < 0:
        return bad_request("temperature must be a non-negative number")

    if model is None or tokenizer is None:
        return error_response(503, "模型未加载", "service_unavailable")

    # model.generate is CPU-bound and can take seconds; run it in a worker
    # thread so the event loop (and /health, other requests) is not blocked.
    try:
        result = await run_in_threadpool(
            generate_chat_response, messages, max_tokens, temperature
        )
        if "error" in result:
            raise RuntimeError(result["error"])
    except Exception as e:
        logger.error(f"Chat Completions 错误: {str(e)}", exc_info=True)
        return error_response(500, str(e), "internal_server_error")

    choice = {
        "index": 0,
        "message": {
            "role": "assistant",
            "content": result["text"],
        },
        # Use the generator's finish reason when it reports one.
        "finish_reason": result.get("finish_reason", "stop"),
    }
    response_data = {
        # Random id: a millisecond timestamp collides for concurrent requests.
        "id": f"chatcmpl-{uuid.uuid4().hex}",
        "object": "chat.completion",
        "created": int(time.time()),
        "model": MODEL_NAME,
        "choices": [choice],
    }
    prompt_tokens = result.get("prompt_tokens")
    completion_tokens = result.get("completion_tokens")
    if prompt_tokens is not None and completion_tokens is not None:
        response_data["usage"] = {
            "prompt_tokens": prompt_tokens,
            "completion_tokens": completion_tokens,
            "total_tokens": prompt_tokens + completion_tokens,
        }
    return response_data
if __name__ == "__main__":
    import uvicorn

    # HOST / PORT may be overridden from the environment; the defaults
    # (0.0.0.0:7860) are the previously hard-coded values.
    uvicorn.run(
        app,
        host=os.getenv("HOST", "0.0.0.0"),
        port=int(os.getenv("PORT", "7860")),
        workers=1,  # each extra worker process would load its own model copy
        log_level="info",
    )