"""Minimal OpenAI-compatible chat-completions server for a Qwen chat model.

Endpoints:
    POST /v1/chat/completions  -- optionally protected by Bearer API keys
    GET  /health               -- reports whether the model is loaded
    GET  /                     -- banner
The model is loaded once at startup and runs on CPU.
"""
import os
import time
import uuid
import secrets
import logging
import threading
from typing import Optional

from fastapi import FastAPI, Request, HTTPException, Depends, status
from fastapi.concurrency import run_in_threadpool
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from fastapi.responses import JSONResponse
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import gc

# Logging configuration
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Global model state, populated by load_model() at startup.
model = None
tokenizer = None

# Configuration
MODEL_NAME = "Qwen/Qwen1.5-0.5B-Chat"
MAX_TOKENS = 512          # default AND upper bound for generated tokens per request
MAX_INPUT_TOKENS = 2048   # prompts longer than this are truncated from the left
DEVICE = "cpu"            # force CPU

# API key configuration. Keys are comma-separated; surrounding whitespace and
# empty entries are dropped so values like "k1, k2," behave as expected.
_DEFAULT_API_KEYS = "your-secret-key-1,your-secret-key-2"
API_KEYS = [k.strip() for k in os.getenv("API_KEYS", _DEFAULT_API_KEYS).split(",") if k.strip()]
API_AUTH_ENABLED = os.getenv("API_AUTH_ENABLED", "true").lower() == "true"

if API_AUTH_ENABLED and os.getenv("API_KEYS") is None:
    # The defaults are public placeholders; make their use visible in the logs.
    logger.warning(
        "API_KEYS is not set; falling back to placeholder keys. "
        "Set API_KEYS before exposing this service."
    )

# Bearer authentication. auto_error=False so a missing/non-Bearer header yields
# None instead of an automatic rejection; verify_api_key decides what to do
# (which is what lets API_AUTH_ENABLED=false actually disable auth).
security = HTTPBearer(auto_error=False)

# model.generate runs in a worker thread (see create_chat_completion); this lock
# serialises generation so concurrent requests don't run the CPU model at once.
_generate_lock = threading.Lock()

# Roles rendered into the ChatML prompt; messages with any other role are skipped.
_CHAT_ROLES = ("system", "user", "assistant")


def verify_api_key(credentials: Optional[HTTPAuthorizationCredentials] = Depends(security)):
    """Validate the Bearer API key on the request.

    Returns:
        True when authentication is disabled or the presented key is valid.

    Raises:
        HTTPException: 401 if the Authorization header is missing, does not
            use the Bearer scheme, or carries a key not in API_KEYS.
    """
    if not API_AUTH_ENABLED:
        return True
    if credentials is None:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Not authenticated",
            headers={"WWW-Authenticate": "Bearer"},
        )
    # Authentication scheme names are case-insensitive (RFC 7235).
    if credentials.scheme.lower() != "bearer":
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid authentication scheme. Use 'Bearer' token",
            headers={"WWW-Authenticate": "Bearer"},
        )
    presented = credentials.credentials.encode("utf-8")
    # Constant-time comparison against every key (list, not generator, so the
    # loop does not short-circuit) to avoid leaking key material via timing.
    matches = [secrets.compare_digest(presented, key.encode("utf-8")) for key in API_KEYS]
    if not any(matches):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid API key",
            headers={"WWW-Authenticate": "Bearer"},
        )
    return True


def load_model():
    """Load the tokenizer and model into the module-level globals.

    The globals are only assigned once both objects are fully ready, so a
    failure never leaves a half-initialised model visible to request handlers.

    Returns:
        True on success; False on any failure (the error is logged).
    """
    global model, tokenizer
    try:
        logger.info(f"开始加载模型: {MODEL_NAME}")
        tok = AutoTokenizer.from_pretrained(
            MODEL_NAME,
            trust_remote_code=True
        )
        if tok.pad_token is None:
            tok.pad_token = tok.eos_token
        # Over-long prompts must lose their *oldest* tokens: the prompt ends with
        # the "<|im_start|>assistant\n" header that generation depends on.
        tok.truncation_side = "left"

        # Half precision is slow or unsupported for many CPU kernels; keep
        # float16 only for accelerator devices.
        dtype = torch.float32 if DEVICE == "cpu" else torch.float16
        mdl = AutoModelForCausalLM.from_pretrained(
            MODEL_NAME,
            torch_dtype=dtype,
            device_map=None,
            low_cpu_mem_usage=True,
            trust_remote_code=True
        )
        mdl = mdl.to(DEVICE)
        mdl.eval()

        tokenizer, model = tok, mdl
        logger.info("模型加载成功")
        return True
    except Exception as e:
        logger.error(f"模型加载失败: {e}", exc_info=True)
        return False


def _content_to_text(content):
    """Flatten message content (str, list of parts, or None) to stripped text.

    For list content (multimodal format) only ``{"type": "text"}`` parts and
    bare strings are kept, joined with single spaces.
    """
    if content is None:
        return ""
    if isinstance(content, list):
        texts = []
        for item in content:
            if isinstance(item, dict):
                if item.get("type") == "text":
                    texts.append(str(item.get("text") or ""))
            elif isinstance(item, str):
                texts.append(item)
        return " ".join(t for t in texts if t).strip()
    return str(content).strip()


def apply_chat_template(messages):
    """Render OpenAI-style messages into a Qwen ChatML prompt string.

    Each dict message whose role is system/user/assistant and whose content is
    non-empty becomes ``<|im_start|>{role}\\n{content}<|im_end|>\\n``. Other
    messages are skipped. The prompt always ends with an open assistant turn.
    """
    parts = []
    for msg in messages:
        if not isinstance(msg, dict):
            continue
        # `or ""` guards against an explicit null role.
        role = str(msg.get("role") or "").lower()
        if role not in _CHAT_ROLES:
            continue
        content_str = _content_to_text(msg.get("content"))
        if not content_str:
            continue
        parts.append(f"<|im_start|>{role}\n{content_str}<|im_end|>\n")
    parts.append("<|im_start|>assistant\n")
    return "".join(parts)


def generate_chat_response(messages, max_tokens=512, temperature=0.7):
    """Generate an assistant reply for the given chat messages.

    Args:
        messages: list of OpenAI-style message dicts.
        max_tokens: requested number of new tokens; clamped to [1, MAX_TOKENS].
        temperature: sampling temperature; values <= 0 select greedy decoding.

    Returns:
        On success: {"text", "finish_reason", "prompt_tokens", "completion_tokens"}.
        On failure: {"error": <message>} (the exception is logged).
    """
    if model is None or tokenizer is None:
        return {"error": "模型未加载"}

    try:
        prompt = apply_chat_template(messages)
        logger.info(f"输入文本类型: {type(prompt)}, 长度: {len(prompt)}")

        inputs = tokenizer(
            [prompt],
            return_tensors="pt",
            truncation=True,
            max_length=MAX_INPUT_TOKENS,  # keep the context small for CPU speed
            padding=True
        )
        inputs = {k: v.to(DEVICE) for k, v in inputs.items()}
        prompt_len = int(inputs["input_ids"].shape[1])

        max_new_tokens = max(1, min(int(max_tokens), MAX_TOKENS))
        gen_kwargs = {
            "max_new_tokens": max_new_tokens,
            "repetition_penalty": 1.05,
            "pad_token_id": tokenizer.eos_token_id,
            "eos_token_id": tokenizer.eos_token_id,
        }
        if temperature > 0:
            gen_kwargs.update(do_sample=True, temperature=temperature, top_p=0.85)
        else:
            # Sampling requires a strictly positive temperature; 0 means greedy.
            gen_kwargs["do_sample"] = False

        with _generate_lock, torch.no_grad():
            outputs = model.generate(**inputs, **gen_kwargs)

        # generate() returns prompt + continuation; decode only the continuation.
        new_tokens = outputs[0][prompt_len:]
        completion_tokens = int(new_tokens.shape[0])
        response = tokenizer.decode(new_tokens, skip_special_tokens=True).strip()

        hit_eos = completion_tokens > 0 and int(new_tokens[-1]) == tokenizer.eos_token_id
        finish_reason = "length" if completion_tokens >= max_new_tokens and not hit_eos else "stop"

        return {
            "text": response,
            "finish_reason": finish_reason,
            "prompt_tokens": prompt_len,
            "completion_tokens": completion_tokens,
        }

    except Exception as e:
        logger.error(f"生成失败: {str(e)}", exc_info=True)
        return {"error": str(e)}


def _error_response(status_code, message, error_type):
    """Build an OpenAI-style JSON error response."""
    return JSONResponse(
        status_code=status_code,
        content={
            "error": {
                "message": message,
                "type": error_type
            }
        }
    )


# FastAPI application
app = FastAPI(
    title="Qwen OpenAI-compatible API",
    version="1.0",
    description="仅提供 /v1/chat/completions 端点"
)


@app.on_event("startup")
async def startup_event():
    """Load the model when the server starts; failures are logged, not raised."""
    if load_model():
        logger.info("服务启动完成")
    else:
        logger.error("模型加载失败,服务可能无法正常工作")


# Health check
@app.get("/health")
async def health_check():
    """Report whether the model is loaded, with the current Unix timestamp."""
    return {
        "status": "healthy" if model is not None else "model loading failed",
        "model_loaded": model is not None,
        "timestamp": int(time.time())
    }


# Root path
@app.get("/")
async def root():
    """Return a short banner describing the single supported endpoint."""
    return {"message": "Qwen API 服务运行中,仅支持 /v1/chat/completions"}


# Core endpoint
@app.post("/v1/chat/completions")
async def create_chat_completion(
    request: Request,
    auth_valid: bool = Depends(verify_api_key)
):
    """Handle an OpenAI-style (non-streaming) chat completion request.

    Client errors (bad JSON, bad messages or parameters, stream=true) map to
    400; a model that is not loaded maps to 503; any other failure maps to 500.
    """
    try:
        try:
            data = await request.json()
        except ValueError:  # JSONDecodeError / UnicodeDecodeError
            return _error_response(400, "Request body must be valid JSON", "invalid_request_error")
        if not isinstance(data, dict):
            return _error_response(400, "Request body must be a JSON object", "invalid_request_error")

        messages = data.get("messages", [])
        if not messages or not isinstance(messages, list):
            return _error_response(400, "messages 必须是非空列表", "invalid_request_error")
        if not all(isinstance(m, dict) for m in messages):
            return _error_response(400, "Each message must be a JSON object", "invalid_request_error")
        if data.get("stream"):
            return _error_response(400, "stream=true is not supported", "invalid_request_error")

        raw_max_tokens = data.get("max_tokens")
        raw_temperature = data.get("temperature")
        try:
            max_tokens = MAX_TOKENS if raw_max_tokens is None else int(raw_max_tokens)
            temperature = 0.7 if raw_temperature is None else float(raw_temperature)
        except (TypeError, ValueError):
            return _error_response(
                400, "max_tokens must be an integer and temperature a number", "invalid_request_error"
            )
        if max_tokens < 1:
            return _error_response(400, "max_tokens must be >= 1", "invalid_request_error")
        if temperature < 0:
            return _error_response(400, "temperature must be >= 0", "invalid_request_error")

        logger.info(f"收到请求: messages_count={len(messages)}")

        if model is None or tokenizer is None:
            return _error_response(503, "模型未加载", "service_unavailable")

        # Generation is CPU-bound and synchronous; run it off the event loop so
        # /health and other requests stay responsive.
        result = await run_in_threadpool(generate_chat_response, messages, max_tokens, temperature)

        if "error" in result:
            raise RuntimeError(result["error"])

        prompt_tokens = result.get("prompt_tokens", 0)
        completion_tokens = result.get("completion_tokens", 0)
        response_data = {
            "id": f"chatcmpl-{uuid.uuid4().hex}",
            "object": "chat.completion",
            "created": int(time.time()),
            "model": MODEL_NAME,
            "choices": [
                {
                    "index": 0,
                    "message": {
                        "role": "assistant",
                        "content": result["text"]
                    },
                    "finish_reason": result.get("finish_reason", "stop")
                }
            ],
            "usage": {
                "prompt_tokens": prompt_tokens,
                "completion_tokens": completion_tokens,
                "total_tokens": prompt_tokens + completion_tokens
            }
        }
        return response_data

    except Exception as e:
        logger.error(f"Chat Completions 错误: {str(e)}", exc_info=True)
        return _error_response(500, str(e), "internal_server_error")


if __name__ == "__main__":
    import uvicorn
    uvicorn.run(
        app,
        host="0.0.0.0",
        port=7860,
        workers=1,
        log_level="info"
    )