import os
import time
import logging
import gc
import secrets
from typing import Optional

from fastapi import FastAPI, Request, HTTPException, Depends, status
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from fastapi.responses import JSONResponse
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
|
|
| |
# Configure root logging at import time: timestamp, level name, then message.
_LOG_FORMAT = '%(asctime)s - %(levelname)s - %(message)s'
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)

# Module-level logger shared by the functions in this file.
logger = logging.getLogger(__name__)
|
|
| |
# Model handles. Both stay None until load_model() assigns them.
model = None
tokenizer = None

# Model settings.
MODEL_NAME = "Qwen/Qwen1.5-0.5B-Chat"
# Default used for a request's "max_tokens" when the client omits it.
MAX_TOKENS = 512
DEVICE = "cpu"

# API authentication settings, read from the environment once at import time.
# API_KEYS is a comma-separated list; surrounding whitespace is stripped and
# empty entries are dropped so values like "a, b," parse as ["a", "b"].
# WARNING: the fallback keys below are publicly known placeholders. Always set
# the API_KEYS environment variable in a real deployment.
API_KEYS = [
    key.strip()
    for key in os.getenv(
        "API_KEYS", "your-secret-key-1,your-secret-key-2"
    ).split(",")
    if key.strip()
]
# Only the value "true" (case-insensitive, whitespace ignored) enables auth.
API_AUTH_ENABLED = (
    os.getenv("API_AUTH_ENABLED", "true").strip().lower() == "true"
)
|
|
| |
# auto_error=False makes the scheme return None instead of rejecting the
# request itself, so verify_api_key() decides the response. Without it,
# requests lacking an Authorization header are refused even when
# API_AUTH_ENABLED is false.
security = HTTPBearer(auto_error=False)


def verify_api_key(
    credentials: Optional[HTTPAuthorizationCredentials] = Depends(security),
) -> bool:
    """Validate the bearer API key sent with the request.

    Args:
        credentials: Parsed ``Authorization`` header supplied by the
            ``security`` dependency, or None when the header is missing or
            does not use the bearer scheme.

    Returns:
        True when authentication is disabled or the key is accepted.

    Raises:
        HTTPException: 401 with a ``WWW-Authenticate: Bearer`` header when
            authentication is enabled and the credentials are missing, use
            another scheme, or carry a key that is not in ``API_KEYS``.
    """
    if not API_AUTH_ENABLED:
        return True

    if credentials is None:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Not authenticated",
            headers={"WWW-Authenticate": "Bearer"},
        )

    # Authentication scheme names are case-insensitive, so "bearer" is valid.
    if credentials.scheme.lower() != "bearer":
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid authentication scheme. Use 'Bearer' token",
            headers={"WWW-Authenticate": "Bearer"},
        )

    # Compare in constant time and visit every configured key, so response
    # timing does not reveal how much of a key matched or which one matched.
    # Bytes are used because compare_digest rejects non-ASCII str input.
    supplied = credentials.credentials.encode("utf-8")
    key_accepted = False
    for known_key in API_KEYS:
        if secrets.compare_digest(supplied, known_key.encode("utf-8")):
            key_accepted = True

    if not key_accepted:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid API key",
            headers={"WWW-Authenticate": "Bearer"},
        )
    return True
|
|
def load_model():
    """Load the tokenizer and model named by MODEL_NAME into module globals.

    The global ``tokenizer`` and ``model`` are assigned only after both have
    loaded, so a failure never leaves one set and the other unset.

    Returns:
        True when both objects were loaded, False on any failure. Failures
        are logged with their traceback and never raised to the caller.
    """
    global model, tokenizer
    try:
        logger.info(f"开始加载模型: {MODEL_NAME}")
        # NOTE(review): trust_remote_code=True allows Python code shipped
        # with the checkpoint to run; confirm it is required for this model.
        loaded_tokenizer = AutoTokenizer.from_pretrained(
            MODEL_NAME,
            trust_remote_code=True
        )
        if loaded_tokenizer.pad_token is None:
            loaded_tokenizer.pad_token = loaded_tokenizer.eos_token

        # Half precision is meant for accelerators; on CPU it is slow and
        # some kernels may be unavailable, so use float32 there.
        on_cpu = str(DEVICE).startswith("cpu")
        weights_dtype = torch.float32 if on_cpu else torch.float16

        loaded_model = AutoModelForCausalLM.from_pretrained(
            MODEL_NAME,
            torch_dtype=weights_dtype,
            device_map=None,
            low_cpu_mem_usage=True,
            trust_remote_code=True
        )
        loaded_model = loaded_model.to(DEVICE)
        loaded_model.eval()
    except Exception as e:
        # Deliberate catch-all: the caller only needs a success flag.
        logger.error(f"模型加载失败: {e}", exc_info=True)
        return False

    tokenizer = loaded_tokenizer
    model = loaded_model
    logger.info("模型加载成功")
    return True
|
|
def _message_text(content):
    """Flatten one message's ``content`` value into a stripped string."""
    if content is None:
        # A JSON null must not be rendered as the literal text "None".
        return ""
    if isinstance(content, list):
        text_parts = []
        for item in content:
            if isinstance(item, dict):
                # Only parts typed "text" are kept; other part types are
                # skipped.
                if item.get("type") == "text" and item.get("text") is not None:
                    text_parts.append(str(item["text"]))
            elif isinstance(item, str):
                text_parts.append(item)
        return " ".join(part for part in text_parts if part).strip()
    return str(content).strip()


def apply_chat_template(messages):
    """Convert a list of chat messages into a Qwen prompt string.

    Each message with role ``system``, ``user`` or ``assistant``
    (case-insensitive) and non-empty text becomes one
    ``<|im_start|>{role}\\n{text}<|im_end|>\\n`` segment. Messages that are
    not dicts, have another or missing role, or have empty/null content are
    skipped. The prompt always ends with an open assistant turn.

    Args:
        messages: Iterable of dicts with ``role`` and ``content`` keys.
            ``content`` may be a string or a list of parts, where a part is
            a plain string or a dict like ``{"type": "text", "text": ...}``.

    Returns:
        The assembled prompt string.
    """
    known_roles = ("system", "user", "assistant")
    segments = []
    for msg in messages or []:
        if not isinstance(msg, dict):
            continue
        role = str(msg.get("role") or "").lower()
        if role not in known_roles:
            continue
        content_str = _message_text(msg.get("content"))
        if not content_str:
            continue
        segments.append(f"<|im_start|>{role}\n{content_str}<|im_end|>\n")

    # Open the assistant turn so generation continues as the reply.
    segments.append("<|im_start|>assistant\n")
    return "".join(segments)
|
|
def generate_chat_response(messages, max_tokens=512, temperature=0.7):
    """Generate an assistant reply for the given chat messages.

    Args:
        messages: Chat messages, passed to ``apply_chat_template``.
        max_tokens: Requested number of new tokens. None falls back to
            ``MAX_TOKENS``; the value is clamped to ``1..MAX_TOKENS``.
        temperature: Sampling temperature. None falls back to 0.7. A value
            of 0 or below switches to greedy decoding, because sampling
            requires a strictly positive temperature.

    Returns:
        ``{"text": reply}`` on success, or ``{"error": message}`` when the
        model is not loaded or generation fails. Exceptions are logged and
        never raised to the caller.
    """
    if model is None or tokenizer is None:
        return {"error": "模型未加载"}

    try:
        prompt = apply_chat_template(messages)
        logger.info(f"输入文本类型: {type(prompt)}, 长度: {len(prompt)}")

        inputs = tokenizer(
            [prompt],
            return_tensors="pt",
            truncation=True,
            max_length=2048,
            padding=True
        )
        inputs = {k: v.to(DEVICE) for k, v in inputs.items()}

        # Honour the caller's budget instead of a hard-coded value, bounded
        # by the module-wide cap.
        requested = MAX_TOKENS if max_tokens is None else int(max_tokens)
        new_token_budget = max(1, min(requested, MAX_TOKENS))

        generation_kwargs = {
            "max_new_tokens": new_token_budget,
            "repetition_penalty": 1.05,
            "pad_token_id": tokenizer.eos_token_id,
            "eos_token_id": tokenizer.eos_token_id,
        }
        temperature = 0.7 if temperature is None else float(temperature)
        if temperature > 0:
            generation_kwargs.update(
                do_sample=True,
                temperature=temperature,
                top_p=0.85,
            )
        else:
            # temperature <= 0 is invalid for sampling; decode greedily.
            generation_kwargs["do_sample"] = False

        with torch.no_grad():
            outputs = model.generate(**inputs, **generation_kwargs)

        # The output holds prompt tokens followed by the continuation;
        # decode only the newly generated part.
        prompt_length = inputs["input_ids"].shape[1]
        response = tokenizer.decode(
            outputs[0][prompt_length:],
            skip_special_tokens=True,
        ).strip()
        return {"text": response}

    except Exception as e:
        logger.error(f"生成失败: {str(e)}", exc_info=True)
        return {"error": str(e)}
|
|
| |
# FastAPI application object; the route decorators in this file register on it.
_APP_METADATA = {
    "title": "Qwen OpenAI-compatible API",
    "version": "1.0",
    "description": "仅提供 /v1/chat/completions 端点",
}
app = FastAPI(**_APP_METADATA)
|
|
@app.on_event("startup")
async def startup_event():
    """Load the model when the application starts and log the outcome.

    A failed load is only logged; the application keeps starting.
    """
    model_ready = load_model()
    if not model_ready:
        logger.error("模型加载失败,服务可能无法正常工作")
        return
    logger.info("服务启动完成")
|
|
| |
@app.get("/health")
async def health_check():
    """Report whether the model global is set, with the current Unix time."""
    model_ready = model is not None
    if model_ready:
        status_text = "healthy"
    else:
        status_text = "model loading failed"
    now = int(time.time())
    return {
        "status": status_text,
        "model_loaded": model_ready,
        "timestamp": now,
    }
|
|
| |
@app.get("/")
async def root():
    """Return a short banner message for the service root."""
    banner = "Qwen API 服务运行中,仅支持 /v1/chat/completions"
    return {"message": banner}
|
|
| |
def _error_response(status_code, message, error_type):
    """Build a JSON error response of the form {"error": {message, type}}."""
    return JSONResponse(
        status_code=status_code,
        content={
            "error": {
                "message": message,
                "type": error_type
            }
        }
    )


@app.post("/v1/chat/completions")
async def create_chat_completion(
    request: Request,
    auth_valid: bool = Depends(verify_api_key)
):
    """Handle a chat completion request.

    Reads ``messages`` (required, non-empty list), ``max_tokens`` (default
    ``MAX_TOKENS``) and ``temperature`` (default 0.7) from the JSON body and
    returns a ``chat.completion`` object with a single assistant choice.

    Returns:
        The completion dict on success. A JSON error body with status 400
        when the request is malformed (invalid JSON, non-object body,
        missing or invalid ``messages``), or status 500 when generation
        fails.
    """
    try:
        # Client mistakes are reported as 400, not as server errors.
        try:
            data = await request.json()
        except ValueError:
            return _error_response(
                status.HTTP_400_BAD_REQUEST,
                "Request body must be valid JSON",
                "invalid_request_error",
            )
        if not isinstance(data, dict):
            return _error_response(
                status.HTTP_400_BAD_REQUEST,
                "Request body must be a JSON object",
                "invalid_request_error",
            )

        messages = data.get("messages", [])
        # Validate before len() is used below, so a non-list value cannot
        # raise.
        if not messages or not isinstance(messages, list):
            return _error_response(
                status.HTTP_400_BAD_REQUEST,
                "messages 必须是非空列表",
                "invalid_request_error",
            )

        max_tokens = data.get("max_tokens", MAX_TOKENS)
        temperature = data.get("temperature", 0.7)

        logger.info(f"收到请求: messages_count={len(messages)}")

        # NOTE(review): generate_chat_response is synchronous and is called
        # directly inside this coroutine, so the event loop is blocked while
        # the model generates. Consider a worker thread if other endpoints
        # must stay responsive.
        result = generate_chat_response(messages, max_tokens, temperature)

        if "error" in result:
            raise RuntimeError(result["error"])

        response_data = {
            "id": f"chatcmpl-{int(time.time()*1000)}",
            "object": "chat.completion",
            "created": int(time.time()),
            "model": MODEL_NAME,
            "choices": [
                {
                    "index": 0,
                    "message": {
                        "role": "assistant",
                        "content": result["text"]
                    },
                    "finish_reason": "stop"
                }
            ]
        }

        return response_data

    except Exception as e:
        # Top-level boundary: log with traceback and answer with a JSON 500.
        logger.error(f"Chat Completions 错误: {str(e)}", exc_info=True)
        return _error_response(
            status.HTTP_500_INTERNAL_SERVER_ERROR,
            str(e),
            "internal_server_error",
        )
|
|
if __name__ == "__main__":
    import uvicorn

    # Serve on all interfaces, port 7860, with a single worker process.
    server_options = {
        "host": "0.0.0.0",
        "port": 7860,
        "workers": 1,
        "log_level": "info",
    }
    uvicorn.run(app, **server_options)