# NOTE(review): removed non-source residue that was pasted into this file
# (Hugging Face Spaces page header, git blame hashes, and a line-number
# gutter). It was not Python and made the file a syntax error.
import asyncio
import json
import logging
import os

import torch
from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from transformers import (
    AutoModelForCausalLM, AutoTokenizer,
    BitsAndBytesConfig, TextStreamer
)
# 1. Basic setup: structured logging and the FastAPI application instance.
logging.basicConfig(level=logging.INFO, format="%(asctime)s-%(name)s-%(levelname)s-%(message)s")
logger = logging.getLogger("inference_node")
app = FastAPI(title="推理节点服务(Qwen-7B)")
# 2. Model selection (Qwen-7B is a public model; no HF token is required).
MODEL_NAME = os.getenv("MODEL_NAME", "Qwen/Qwen-7B")
HF_TOKEN = os.getenv("HUGGINGFACE_HUB_TOKEN", "")  # may be left empty for public models
# 3. 4-bit quantization config (author's note: targets ~16GB hosts, roughly
#    4-5GB of VRAM for the quantized weights — TODO confirm on the target GPU).
bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_use_double_quant=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch.bfloat16
)
# 4. Load tokenizer and model at import time; abort the process on failure.
#    (Original author's note: explicitly supply special tokens that some
#    tokenizer configs are missing.)
try:
    logger.info(f"开始加载模型:{MODEL_NAME}(4bit量化)")
    tokenizer = AutoTokenizer.from_pretrained(
        MODEL_NAME,
        token=HF_TOKEN,
        padding_side="right",  # pad on the right so generation is not truncated
        trust_remote_code=True,  # required for Qwen (loads its custom tokenizer code)
        eos_token="<|endoftext|>",  # explicit end-of-sequence token (compat with older configs)
        pad_token="<|endoftext|>"  # explicit padding token (avoids generation warnings)
    )
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        quantization_config=bnb_config,
        device_map="auto",  # let the loader place layers automatically (GPU first)
        token=HF_TOKEN,
        trust_remote_code=True,
        torch_dtype=torch.bfloat16  # match the quantization compute dtype above
    )
    # Streamer that skips the prompt and special tokens in its output.
    # NOTE(review): this single TextStreamer instance is module-global and
    # shared by all requests, and TextStreamer writes to stdout — confirm
    # this is intended under concurrent requests.
    streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    logger.info(f"模型 {MODEL_NAME} 加载成功!显存占用约 4-5GB(4bit 量化)")
except Exception as e:
    logger.error(f"模型加载失败:{str(e)}", exc_info=True)
    raise SystemExit(f"服务终止:{str(e)}")
class NodeInferenceRequest(BaseModel):
    """Request body for the /node/stream-infer endpoint.

    Attributes:
        prompt: the user's question / input text.
        max_tokens: upper bound on newly generated tokens (default 1024).
        temperature: sampling temperature in (0, 1]; higher = more diverse.
    """

    prompt: str
    max_tokens: int = 1024
    temperature: float = 0.7
# 6. Streaming inference endpoint (bypasses chat_template on purpose,
#    per the original design, to stay compatible with older transformers).
@app.post("/node/stream-infer")
async def stream_infer(req: NodeInferenceRequest, request: Request):
    """Generate a completion for ``req.prompt`` and stream it back as NDJSON.

    Each response line is a JSON object ``{"chunk": <text>, "finish": <bool>}``;
    the final line has ``finish=true`` and an empty chunk.

    Raises:
        HTTPException: 500 with the underlying error message on any failure.
    """
    try:
        user_prompt = req.prompt.strip()
        # Manually build the turn markers instead of tokenizer.apply_chat_template.
        # NOTE(review): official Qwen chat checkpoints use <|im_start|>/<|im_end|>
        # markers — confirm this ad-hoc <|user|>/<|end|> format against the
        # deployed checkpoint's expectations.
        input_text = f"<|user|>{user_prompt}<|end|><|assistant|>"
        # Encode to tensors and move them to the model's device.
        inputs = tokenizer(
            input_text,
            return_tensors="pt",
            truncation=True,   # truncate over-long input to avoid OOM
            max_length=2048
        ).to(model.device)

        async def generate_chunks():
            loop = asyncio.get_running_loop()
            # Run the blocking generate() in the default thread pool so the
            # event loop stays responsive. Note: the full sequence is produced
            # before the first chunk is yielded — "streaming" here refers to
            # the chunked HTTP response, not token-by-token generation.
            outputs = await loop.run_in_executor(
                None,
                lambda: model.generate(
                    **inputs,
                    streamer=streamer,
                    max_new_tokens=req.max_tokens,
                    do_sample=True,
                    temperature=req.temperature,
                    pad_token_id=tokenizer.pad_token_id,
                    eos_token_id=tokenizer.eos_token_id
                )
            )
            # Keep only the newly generated tokens (drop the prompt portion).
            input_token_len = inputs["input_ids"].shape[1]
            generated_tokens = outputs[0][input_token_len:]
            # FIX: decode cumulatively and emit only the newly added text.
            # Decoding one token at a time corrupts multi-byte characters
            # (e.g. CJK) whose UTF-8 bytes span several tokens.
            sent_len = 0
            for end in range(1, len(generated_tokens) + 1):
                # Stop generating output for clients that have gone away.
                if await request.is_disconnected():
                    logger.info("客户端已断开连接,停止生成")
                    break
                text_so_far = tokenizer.decode(
                    generated_tokens[:end],
                    skip_special_tokens=True,
                    clean_up_tokenization_spaces=True
                )
                delta = text_so_far[sent_len:]
                sent_len = len(text_so_far)
                if not delta:
                    continue
                # FIX: json.dumps performs complete JSON escaping (quotes,
                # backslashes, newlines, control characters). The previous
                # hand-rolled escaping only handled double quotes and emitted
                # invalid JSON for any other special character.
                yield json.dumps({"chunk": delta, "finish": False}, ensure_ascii=False) + "\n"
            # Final line tells the client generation is complete.
            yield json.dumps({"chunk": "", "finish": True}, ensure_ascii=False) + "\n"

        return StreamingResponse(
            generate_chunks(),
            media_type="application/x-ndjson"
        )
    except Exception as e:
        logger.error(f"推理失败:{str(e)}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"推理服务异常:{str(e)}")
# 7. Health-check endpoint for service monitoring.
@app.get("/node/health")
async def node_health():
    """Report whether the model/tokenizer loaded and basic service metadata."""
    is_model_ready = model is not None and tokenizer is not None
    # FIX: asyncio.get_event_loop() is deprecated inside a running coroutine;
    # use get_running_loop() (consistent with the inference endpoint).
    loop = asyncio.get_running_loop()
    return {
        "status": "healthy" if is_model_ready else "unhealthy",
        "model": MODEL_NAME,
        "support_stream": True,
        "note": "Qwen-7B 4bit量化(适配16G内存),绕开chat_template兼容旧版本",
        "timestamp": str(loop.time())
    }
# 8. Script entry point (only runs when executed directly, not on import).
if __name__ == "__main__":
    import uvicorn
    # host=0.0.0.0 allows external access; 7860 is the default Spaces port.
    # FIX: removed a stray " |" artifact after the closing parenthesis that
    # made the file a syntax error.
    uvicorn.run(
        app,
        host="0.0.0.0",
        port=7860,
        log_level="info",
        workers=1  # single process: workers cannot share one loaded model
    )