gpt-chat-api / inference_node.py
fiewolf1000's picture
Update inference_node.py
ea17503 verified
from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
import os
import logging
import torch
import asyncio
from transformers import (
AutoModelForCausalLM, AutoTokenizer,
BitsAndBytesConfig, TextStreamer
)
# 1. 基础配置
logging.basicConfig(level=logging.INFO, format="%(asctime)s-%(name)s-%(levelname)s-%(message)s")
logger = logging.getLogger("inference_node")
app = FastAPI(title="推理节点服务(Qwen-7B)")
# 2. 模型配置(Qwen-7B 公开模型,无需HF Token)
MODEL_NAME = os.getenv("MODEL_NAME", "Qwen/Qwen-7B")
HF_TOKEN = os.getenv("HUGGINGFACE_HUB_TOKEN", "") # 留空即可
# 3. 4bit量化配置(适配16G内存,显存占用约4-5GB)
bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_use_double_quant=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch.bfloat16
)
# 4. 加载模型(关键:显式处理tokenizer缺失的配置)
try:
logger.info(f"开始加载模型:{MODEL_NAME}(4bit量化)")
tokenizer = AutoTokenizer.from_pretrained(
MODEL_NAME,
token=HF_TOKEN,
padding_side="right", # 右侧padding,避免生成时截断
trust_remote_code=True, # Qwen模型必需(加载自定义tokenizer)
eos_token="<|endoftext|>", # 显式指定结束符(兼容旧版本)
pad_token="<|endoftext|>" # 显式指定padding符(避免生成警告)
)
model = AutoModelForCausalLM.from_pretrained(
MODEL_NAME,
quantization_config=bnb_config,
device_map="auto", # 自动分配设备(优先GPU)
token=HF_TOKEN,
trust_remote_code=True,
torch_dtype=torch.bfloat16 # 匹配量化计算精度
)
# 流式输出配置(跳过提示词,只返回生成内容)
streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
logger.info(f"模型 {MODEL_NAME} 加载成功!显存占用约 4-5GB(4bit 量化)")
except Exception as e:
logger.error(f"模型加载失败:{str(e)}", exc_info=True)
raise SystemExit(f"服务终止:{str(e)}")
# 5. 请求体定义(用户输入prompt和生成参数)
class NodeInferenceRequest(BaseModel):
prompt: str # 用户提问内容
max_tokens: int = 1024 # 最大生成长度(默认1024)
temperature: float = 0.7 # 随机性(0-1,越大越多样)
# 6. 流式推理接口(核心修复:绕开chat_template,直接构建输入)
@app.post("/node/stream-infer")
async def stream_infer(req: NodeInferenceRequest, request: Request):
try:
# --------------------------
# 关键修复:手动构建Qwen原生对话格式
# Qwen要求格式:<|user|>用户输入<|end|><|assistant|>
# --------------------------
user_prompt = req.prompt.strip()
# 构建模型能理解的输入文本(无需依赖chat_template)
input_text = f"<|user|>{user_prompt}<|end|><|assistant|>"
# 编码输入(转换为模型可处理的张量,并移动到GPU)
inputs = tokenizer(
input_text,
return_tensors="pt", # 返回PyTorch张量
truncation=True, # 截断过长输入(避免OOM)
max_length=2048 # 输入最大长度(根据模型能力调整)
).to(model.device)
# 异步生成流式内容(避免阻塞FastAPI主线程)
async def generate_chunks():
loop = asyncio.get_running_loop()
# 在线程池中运行同步的模型生成(不阻塞事件循环)
outputs = await loop.run_in_executor(
None, # 使用默认线程池
lambda: model.generate(
**inputs,
streamer=streamer, # 流式输出支持
max_new_tokens=req.max_tokens, # 最大生成长度
do_sample=True, # 启用采样(生成多样内容)
temperature=req.temperature, # 随机性控制
pad_token_id=tokenizer.pad_token_id, # padding符ID
eos_token_id=tokenizer.eos_token_id # 结束符ID(生成停止标志)
)
)
# 提取生成的内容(排除输入部分,只取新生成的token)
input_token_len = inputs["input_ids"].shape[1] # 输入token长度
generated_tokens = outputs[0][input_token_len:] # 仅保留新生成的token
# 逐token解码并返回(流式输出核心)
for token in generated_tokens:
# 检查客户端是否断开连接(避免无效生成)
if await request.is_disconnected():
logger.info("客户端已断开连接,停止生成")
break
# 解码单个token(跳过特殊符号,如<|end|>)
token_text = tokenizer.decode(
token,
skip_special_tokens=True, # 跳过特殊token(如结束符、分隔符)
clean_up_tokenization_spaces=True # 清理多余空格
)
# 转义双引号(避免JSON格式错误)
escaped_text = token_text.replace('"', '\\"')
# 按NDJSON格式返回(每行一个JSON对象,兼容流式解析)
yield f'{{"chunk":"{escaped_text}","finish":false}}\n'
# 生成结束标志(告知客户端生成完成)
yield '{"chunk":"","finish":true}\n'
# 返回流式响应(媒体类型为application/x-ndjson,支持逐行解析)
return StreamingResponse(
generate_chunks(),
media_type="application/x-ndjson"
)
except Exception as e:
logger.error(f"推理失败:{str(e)}", exc_info=True) # 记录详细错误堆栈
raise HTTPException(status_code=500, detail=f"推理服务异常:{str(e)}")
# 7. 健康检查接口(用于监控服务状态)
@app.get("/node/health")
async def node_health():
# 检查模型和tokenizer是否正常加载
is_model_ready = model is not None and tokenizer is not None
return {
"status": "healthy" if is_model_ready else "unhealthy",
"model": MODEL_NAME,
"support_stream": True,
"note": "Qwen-7B 4bit量化(适配16G内存),绕开chat_template兼容旧版本",
"timestamp": str(asyncio.get_event_loop().time())
}
# 8. 启动服务(仅在直接运行脚本时执行)
if __name__ == "__main__":
import uvicorn
# 启动UVicorn服务(host=0.0.0.0允许外部访问,port=7860为默认端口)
uvicorn.run(
app,
host="0.0.0.0",
port=7860,
log_level="info",
workers=1 # 单进程(模型不支持多进程共享,避免重复加载)
)